Repository: spotify/luigi Branch: master Commit: 0be72e20a2e6 Files: 329 Total size: 2.5 MB Directory structure: gitextract_i5ice6bg/ ├── .coveragerc ├── .github/ │ ├── CODEOWNERS │ ├── ISSUE_TEMPLATE.md │ ├── PULL_REQUEST_TEMPLATE.md │ ├── stale.yml │ └── workflows/ │ ├── codeql.yml │ └── pythonbuild.yml ├── .gitignore ├── .readthedocs.yaml ├── CONTRIBUTING.rst ├── LICENSE ├── README.rst ├── RELEASE-PROCESS.rst ├── SECURITY.md ├── bin/ │ ├── luigi │ └── luigid ├── catalog-info.yaml ├── codecov.yml ├── doc/ │ ├── .gitignore │ ├── Makefile │ ├── central_scheduler.rst │ ├── conf.py │ ├── configuration.rst │ ├── design_and_limitations.rst │ ├── example_top_artists.rst │ ├── execution_model.rst │ ├── index.rst │ ├── logging.rst │ ├── luigi_patterns.rst │ ├── mypy.rst │ ├── parameters.rst │ ├── running_luigi.rst │ ├── tasks.rst │ └── workflows.rst ├── examples/ │ ├── __init__.py │ ├── config.toml │ ├── dynamic_requirements.py │ ├── elasticsearch_index.py │ ├── execution_summary_example.py │ ├── foo.py │ ├── foo_complex.py │ ├── ftp_experiment_outputs.py │ ├── hello_world.py │ ├── kubernetes.py │ ├── per_task_retry_policy.py │ ├── pyspark_wc.py │ ├── spark_als.py │ ├── ssh_remote_execution.py │ ├── terasort.py │ ├── top_artists.py │ ├── top_artists_spark.py │ ├── wordcount.py │ └── wordcount_hadoop.py ├── luigi/ │ ├── __init__.py │ ├── __main__.py │ ├── __version__.py │ ├── batch_notifier.py │ ├── cmdline.py │ ├── cmdline_parser.py │ ├── configuration/ │ │ ├── __init__.py │ │ ├── base_parser.py │ │ ├── cfg_parser.py │ │ ├── core.py │ │ └── toml_parser.py │ ├── contrib/ │ │ ├── __init__.py │ │ ├── azureblob.py │ │ ├── batch.py │ │ ├── beam_dataflow.py │ │ ├── bigquery.py │ │ ├── bigquery_avro.py │ │ ├── datadog_metric.py │ │ ├── dataproc.py │ │ ├── docker_runner.py │ │ ├── dropbox.py │ │ ├── ecs.py │ │ ├── esindex.py │ │ ├── external_daily_snapshot.py │ │ ├── external_program.py │ │ ├── ftp.py │ │ ├── gcp.py │ │ ├── gcs.py │ │ ├── hadoop.py │ │ ├── hadoop_jar.py │ │ ├── 
hdfs/ │ │ │ ├── __init__.py │ │ │ ├── abstract_client.py │ │ │ ├── clients.py │ │ │ ├── config.py │ │ │ ├── error.py │ │ │ ├── format.py │ │ │ ├── hadoopcli_clients.py │ │ │ ├── target.py │ │ │ └── webhdfs_client.py │ │ ├── hive.py │ │ ├── kubernetes.py │ │ ├── lsf.py │ │ ├── lsf_runner.py │ │ ├── mongodb.py │ │ ├── mssqldb.py │ │ ├── mysqldb.py │ │ ├── opener.py │ │ ├── pai.py │ │ ├── pig.py │ │ ├── postgres.py │ │ ├── presto.py │ │ ├── prometheus_metric.py │ │ ├── pyspark_runner.py │ │ ├── rdbms.py │ │ ├── redis_store.py │ │ ├── redshift.py │ │ ├── s3.py │ │ ├── salesforce.py │ │ ├── scalding.py │ │ ├── sge.py │ │ ├── sge_runner.py │ │ ├── simulate.py │ │ ├── spark.py │ │ ├── sparkey.py │ │ ├── sqla.py │ │ ├── ssh.py │ │ ├── target.py │ │ └── webhdfs.py │ ├── date_interval.py │ ├── db_task_history.py │ ├── event.py │ ├── execution_summary.py │ ├── format.py │ ├── freezing.py │ ├── interface.py │ ├── local_target.py │ ├── lock.py │ ├── metrics.py │ ├── mock.py │ ├── mypy.py │ ├── notifications.py │ ├── parameter.py │ ├── process.py │ ├── py.typed │ ├── retcodes.py │ ├── rpc.py │ ├── safe_extractor.py │ ├── scheduler.py │ ├── server.py │ ├── setup_logging.py │ ├── static/ │ │ └── visualiser/ │ │ ├── css/ │ │ │ ├── luigi.css │ │ │ └── tipsy.css │ │ ├── fonts/ │ │ │ └── FontAwesome.otf │ │ ├── index.html │ │ ├── js/ │ │ │ ├── graph.js │ │ │ ├── luigi.js │ │ │ ├── test/ │ │ │ │ └── graph_test.js │ │ │ ├── tipsy.js │ │ │ ├── util.js │ │ │ └── visualiserApp.js │ │ ├── lib/ │ │ │ ├── URI/ │ │ │ │ └── 1.18.2/ │ │ │ │ └── URI.js │ │ │ ├── datatables/ │ │ │ │ └── images/ │ │ │ │ └── Sorting icons.psd │ │ │ └── mustache.js │ │ ├── mockdata/ │ │ │ ├── dep_graph │ │ │ ├── fetch_error │ │ │ └── task_list │ │ └── test.html │ ├── target.py │ ├── task.py │ ├── task_history.py │ ├── task_register.py │ ├── task_status.py │ ├── templates/ │ │ ├── history.html │ │ ├── layout.html │ │ ├── menu.html │ │ ├── recent.html │ │ └── show.html │ ├── tools/ │ │ ├── __init__.py │ │ ├── deps.py │ 
│ ├── deps_tree.py │ │ ├── luigi_grep.py │ │ └── range.py │ ├── util.py │ └── worker.py ├── pyproject.toml ├── scripts/ │ └── ci/ │ ├── conditional_tox.sh │ ├── install_start_azurite.sh │ ├── setup_hadoop_env.sh │ └── stop_azurite.sh ├── test/ │ ├── _mysqldb_test.py │ ├── _test_ftp.py │ ├── auto_namespace_test/ │ │ ├── __init__.py │ │ └── my_namespace_test.py │ ├── batch_notifier_test.py │ ├── choice_parameter_test.py │ ├── clone_test.py │ ├── cmdline_test.py │ ├── config_env_test.py │ ├── config_toml_test.py │ ├── conftest.py │ ├── contrib/ │ │ ├── __init__.py │ │ ├── _webhdfs_test.py │ │ ├── azureblob_test.py │ │ ├── batch_test.py │ │ ├── beam_dataflow_test.py │ │ ├── bigquery_avro_test.py │ │ ├── bigquery_gcloud_test.py │ │ ├── bigquery_test.py │ │ ├── cascading_test.py │ │ ├── datadog_metric_test.py │ │ ├── dataproc_test.py │ │ ├── docker_runner_test.py │ │ ├── dropbox_test.py │ │ ├── ecs_test.py │ │ ├── esindex_test.py │ │ ├── external_daily_snapshot_test.py │ │ ├── external_program_test.py │ │ ├── gcs_test.py │ │ ├── hadoop_jar_test.py │ │ ├── hdfs/ │ │ │ └── webhdfs_client_test.py │ │ ├── hdfs_test.py │ │ ├── hive_test.py │ │ ├── kubernetes_test.py │ │ ├── lsf_test.py │ │ ├── mongo_test.py │ │ ├── mysqldb_test.py │ │ ├── opener_test.py │ │ ├── pai_test.py │ │ ├── pig_test.py │ │ ├── postgres_test.py │ │ ├── postgres_with_server_test.py │ │ ├── presto_test.py │ │ ├── prometheus_metric_test.py │ │ ├── rdbms_test.py │ │ ├── redis_test.py │ │ ├── redshift_test.py │ │ ├── s3_test.py │ │ ├── salesforce_test.py │ │ ├── scalding_test.py │ │ ├── sge_test.py │ │ ├── spark_test.py │ │ ├── sqla_test.py │ │ ├── streaming_test.py │ │ └── test_ssh.py │ ├── create_packages_archive_root/ │ │ ├── module.py │ │ └── package/ │ │ ├── __init__.py │ │ ├── submodule.py │ │ ├── submodule_with_absolute_import.py │ │ ├── submodule_without_imports.py │ │ └── subpackage/ │ │ ├── __init__.py │ │ └── submodule.py │ ├── custom_metrics_test.py │ ├── customized_run_test.py │ ├── 
date_interval_test.py │ ├── date_parameter_test.py │ ├── db_task_history_test.py │ ├── decorator_test.py │ ├── dict_parameter_test.py │ ├── dynamic_import_test.py │ ├── event_callbacks_test.py │ ├── execution_summary_test.py │ ├── factorial_test.py │ ├── fib_test.py │ ├── gcloud-credentials.json.enc │ ├── hdfs_client_test.py │ ├── helpers.py │ ├── helpers_test.py │ ├── import_test.py │ ├── instance_test.py │ ├── instance_wrap_test.py │ ├── interface_test.py │ ├── list_parameter_test.py │ ├── local_target_test.py │ ├── lock_test.py │ ├── metrics_test.py │ ├── mock_test.py │ ├── most_common_test.py │ ├── mypy_test.py │ ├── notifications_test.py │ ├── numerical_parameter_test.py │ ├── optional_parameter_test.py │ ├── other_module.py │ ├── parameter_test.py │ ├── priority_test.py │ ├── range_test.py │ ├── recursion_test.py │ ├── remote_scheduler_test.py │ ├── retcodes_test.py │ ├── rpc_test.py │ ├── runtests.py │ ├── safe_extractor_test.py │ ├── scheduler_api_test.py │ ├── scheduler_message_test.py │ ├── scheduler_parameter_visibilities_test.py │ ├── scheduler_test.py │ ├── scheduler_visualisation_test.py │ ├── server_test.py │ ├── set_task_name_test.py │ ├── setup_logging_test.py │ ├── simulate_test.py │ ├── subtask_test.py │ ├── target_test.py │ ├── task_bulk_complete_test.py │ ├── task_forwarded_attributes_test.py │ ├── task_history_test.py │ ├── task_progress_percentage_test.py │ ├── task_register_test.py │ ├── task_running_resources_test.py │ ├── task_serialize_test.py │ ├── task_status_message_test.py │ ├── task_test.py │ ├── test_sigpipe.py │ ├── test_ssh.py │ ├── testconfig/ │ │ ├── core-site.xml │ │ ├── log4j.properties │ │ ├── logging.cfg │ │ ├── luigi.toml │ │ ├── luigi_local.toml │ │ ├── luigi_logging.toml │ │ └── pyproject.toml │ ├── util_previous_test.py │ ├── util_test.py │ ├── visible_parameters_test.py │ ├── visualiser/ │ │ ├── __init__.py │ │ ├── phantomjs_test.js │ │ └── visualiser_test.py │ ├── worker_external_task_test.py │ ├── 
worker_keep_alive_test.py │ ├── worker_multiprocess_test.py │ ├── worker_parallel_scheduling_test.py │ ├── worker_scheduler_com_test.py │ ├── worker_task_process_test.py │ ├── worker_task_test.py │ ├── worker_test.py │ └── wrap_test.py └── tox.ini ================================================ FILE CONTENTS ================================================ ================================================ FILE: .coveragerc ================================================ [report] omit = luigi/mrrunner.py test/_test_time_generated_module*.py */python?.?/* */site-packages/nose/* *__init__* *test/* */.tox/* */setup.py */bin/luigidc hadoop_test.py minicluster.py [run] parallel=True concurrency=multiprocessing ================================================ FILE: .github/CODEOWNERS ================================================ # The following patterns are used to auto-assign review requests # to specific individuals. Order is important; the last matching # pattern takes the most precedence. # These owners will be the default owners for everything in # the repo. Unless a later match takes precedence, * @dlstadther @spotify/dataex # Specific files, directories, paths, or file types can be # assigned more specifically. contrib/redshift*.py @dlstadther ================================================ FILE: .github/ISSUE_TEMPLATE.md ================================================ ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ ## Description ## Motivation and Context ## Have you tested this? If so, how?
================================================ FILE: .github/stale.yml ================================================ # Number of days of inactivity before an issue becomes stale daysUntilStale: 120 # Number of days of inactivity before a stale issue is closed daysUntilClose: 14 # Issues with these labels will never be considered stale exemptLabels: - pinned - security # Label to use when marking an issue as stale staleLabel: wontfix # Comment to post when marking an issue as stale. Set to `false` to disable markComment: > This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. If closed, you may revisit when your time allows and reopen! Thank you for your contributions. # Comment to post when closing a stale issue. Set to `false` to disable closeComment: false # Limit to only `issues` or `pulls` # only: issues ================================================ FILE: .github/workflows/codeql.yml ================================================ name: "CodeQL" on: push: branches: [ 'master' ] pull_request: # The branches below must be a subset of the branches above branches: [ 'master' ] schedule: - cron: '29 18 * * 0' jobs: analyze: name: Analyze runs-on: ubuntu-latest permissions: actions: read contents: read security-events: write strategy: fail-fast: false matrix: language: [ 'python', 'javascript' ] # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] # Use only 'java' to analyze code written in Java, Kotlin or both # Use only 'javascript' to analyze code written in JavaScript, TypeScript or both # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support steps: - name: Checkout repository uses: actions/checkout@v4 # Initializes the CodeQL tools for scanning. 
- name: Initialize CodeQL uses: github/codeql-action/init@v2 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. # By default, queries listed here will override any specified in a config file. # Prefix the list here with "+" to use these queries and those in the config file. # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs # queries: security-extended,security-and-quality # Autobuild attempts to build any compiled languages (C/C++, C#, Go, Java, or Swift). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild uses: github/codeql-action/autobuild@v2 # ℹ️ Command-line programs to run using the OS shell. # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun # If the Autobuild fails above, remove it and uncomment the following three lines. # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 
# - run: | # echo "Run, Build Application using script" # ./location_of_script_within_repo/buildscript.sh - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@v2 with: category: "/language:${{matrix.language}}" ================================================ FILE: .github/workflows/pythonbuild.yml ================================================ name: Build on: push: branches: - master pull_request: jobs: core: runs-on: ubuntu-22.04 strategy: matrix: include: - tox-env: py310-core - tox-env: py311-core - tox-env: py312-core - tox-env: py313-core steps: - uses: actions/checkout@v6 - name: Set up the latest version of uv uses: astral-sh/setup-uv@v7 with: enable-cache: true cache-dependency-glob: "pyproject.toml" - name: Install dependencies run: | uv tool install --python-preference only-managed --python 3.12 tox --with tox-uv - name: Build env: TOXENV: ${{ matrix.tox-env }} run: uvx --with tox-uv tox run - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: fail_ci_if_error: true verbose: true env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} mysql: runs-on: ubuntu-22.04 strategy: matrix: include: - tox-env: py310-mysql - tox-env: py311-mysql - tox-env: py312-mysql - tox-env: py313-mysql steps: - uses: actions/checkout@v6 - name: Set up the latest version of uv uses: astral-sh/setup-uv@v7 with: enable-cache: true cache-dependency-glob: "pyproject.toml" - name: Install dependencies run: | uv tool install --python-preference only-managed --python 3.12 tox --with tox-uv - name: Setup MySQL DB run: | sudo /etc/init.d/mysql start mysql -e 'create database IF NOT EXISTS luigi_test;' -uroot -proot || true mysql -e 'create user 'travis'@'localhost';' -uroot -proot || true mysql -e 'grant all privileges ON *.* TO 'travis'@'localhost';' -uroot -proot || true - name: Build env: TOXENV: ${{ matrix.tox-env }} run: uvx --with tox-uv tox run - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: fail_ci_if_error: true verbose: 
true env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} postgres: runs-on: ubuntu-22.04 services: postgres: image: postgres env: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres POSTGRES_DB: postgres ports: - 5432:5432 # Set health checks to wait until postgres has started options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 strategy: matrix: include: - tox-env: py310-postgres - tox-env: py311-postgres - tox-env: py312-postgres - tox-env: py313-postgres steps: - uses: actions/checkout@v6 - name: Set up the latest version of uv uses: astral-sh/setup-uv@v7 with: enable-cache: true cache-dependency-glob: "pyproject.toml" - name: Install dependencies run: | uv tool install --python-preference only-managed --python 3.12 tox --with tox-uv - name: Create PSQL database run: | PGPASSWORD=postgres psql -h localhost -p 5432 -c 'create database spotify;' -U postgres - name: Build env: TOXENV: ${{ matrix.tox-env }} run: uvx --with tox-uv tox run - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: fail_ci_if_error: true verbose: true env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} base: runs-on: ubuntu-22.04 env: AWS_DEFAULT_REGION: us-east-1 AWS_ACCESS_KEY_ID: accesskey AWS_SECRET_ACCESS_KEY: secretkey strategy: matrix: include: - tox-env: py310-aws - tox-env: py311-aws - tox-env: py312-aws - tox-env: py313-aws - tox-env: py310-unixsocket OVERRIDE_SKIP_CI_TESTS: True - tox-env: py311-unixsocket OVERRIDE_SKIP_CI_TESTS: True - tox-env: py312-unixsocket OVERRIDE_SKIP_CI_TESTS: True - tox-env: py313-unixsocket OVERRIDE_SKIP_CI_TESTS: True - tox-env: py310-apache - tox-env: py311-apache - tox-env: py312-apache - tox-env: py313-apache - tox-env: py310-azureblob - tox-env: py311-azureblob - tox-env: py312-azureblob - tox-env: py313-azureblob - tox-env: py310-contrib - tox-env: py311-contrib - tox-env: py312-contrib - tox-env: py313-contrib steps: - uses: actions/checkout@v6 - name: Set up the latest version of uv uses: 
astral-sh/setup-uv@v7 with: enable-cache: true cache-dependency-glob: "pyproject.toml" - name: Install dependencies run: | uv tool install --python-preference only-managed --python 3.12 tox --with tox-uv - name: Build env: TOXENV: ${{ matrix.tox-env }} OVERRIDE_SKIP_CI_TESTS: ${{ matrix.OVERRIDE_SKIP_CI_TESTS }} run: uvx --with tox-uv tox run - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: fail_ci_if_error: true verbose: true env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} others: runs-on: ubuntu-22.04 strategy: matrix: include: - tox-env: lint - tox-env: docs - tox-env: typecheck steps: - uses: actions/checkout@v6 - name: Set up the latest version of uv uses: astral-sh/setup-uv@v7 with: enable-cache: true cache-dependency-glob: "pyproject.toml" - name: Install dependencies run: | uv tool install --python-preference only-managed --python 3.12 tox --with tox-uv - name: Build env: TOXENV: ${{ matrix.tox-env }} OVERRIDE_SKIP_CI_TESTS: ${{ matrix.OVERRIDE_SKIP_CI_TESTS }} run: uvx --with tox-uv tox run ================================================ FILE: .gitignore ================================================ .coverage.* doc/api/*.rst test/gcloud-credentials.json .hypothesis/ .nicesetup client.cfg luigi.cfg hadoop_test.py minicluster.py mrrunner.py pig_property_file packages.tar # Ignore the data files data test/data examples/data Vagrantfile *.pickle *.rej *.orig # Created by https://www.gitignore.io ### Python ### # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] # C extensions *.so # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ # NOTE : lib/ prevents inclusion of static/visualiser/lib #lib/ lib64/ parts/ sdist/ var/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml my_dir # Translations *.mo *.pot # Django stuff: *.log # Sphinx documentation doc/_build/ # PyBuilder target/ ### Vim ### [._]*.s[a-w][a-z] [._]s[a-w][a-z] *.un~ Session.vim .netrwhist *~ ### PyCharm ### # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm *.iml ## Directory-based project format: .idea/ # if you remove the above rule, at least ignore the following: # User-specific stuff: # .idea/workspace.xml # .idea/tasks.xml # .idea/dictionaries # Sensitive or high-churn files: # .idea/dataSources.ids # .idea/dataSources.xml # .idea/sqlDataSources.xml # .idea/dynamic.xml # .idea/uiDesigner.xml # Gradle: # .idea/gradle.xml # .idea/libraries # Mongo Explorer plugin: # .idea/mongoSettings.xml ## File-based project format: *.ipr *.iws ## Plugin-specific files: # IntelliJ out/ # mpeltonen/sbt-idea plugin .idea_modules/ # JIRA plugin atlassian-ide-plugin.xml # Crashlytics plugin (for Android Studio and IntelliJ) com_crashlytics_export_strings.xml crashlytics.properties crashlytics-build.properties ### Vagrant ### .vagrant/ ### OSX ### .DS_Store .AppleDouble .LSOverride # Icon must end with two \r Icon # Thumbnails ._* # Files that might appear on external disk .Spotlight-V100 .Trashes # Directories potentially created on remote AFP share .AppleDB .AppleDesktop Network Trash Folder Temporary Items .apdisk .python-version ================================================ FILE: .readthedocs.yaml ================================================ version: 2 build: os: ubuntu-24.04 tools: python: "3.13" jobs: pre_create_environment: - asdf plugin add uv - asdf install uv latest - asdf global uv latest create_environment: - uv venv "${READTHEDOCS_VIRTUALENV_PATH}" install: - UV_PROJECT_ENVIRONMENT="${READTHEDOCS_VIRTUALENV_PATH}" uv sync --frozen --group docs sphinx: configuration: 
doc/conf.py formats: - pdf - epub ================================================ FILE: CONTRIBUTING.rst ================================================ Code of conduct --------------- This project adheres to the `Open Code of Conduct `_. By participating, you are expected to honor this code. Running the tests ----------------- We are always happy to receive Pull Requests. When you open a PR, it will automatically be built and tested by GitHub Actions, so you're not strictly required to test the patch locally before submitting it. If you do want to run the tests locally, you'll need to run the commands below: .. code:: bash curl -LsSf https://astral.sh/uv/install.sh | sh uv tool install tox --with tox-uv You will need a ``tox --version`` of at least 4.22. .. code:: bash # These commands are pretty fast and will tell you if you've # broken something major: tox run -e lint tox run -e py310-core # You can also test particular files for even faster iterations tox run -e py310-core -- test/rpc_test.py # The visualiser tests require phantomjs to be installed on your path tox run -e visualiser # And some of the others involve downloading and running Hadoop: tox run -e py38-cdh tox run -e py39-hdp Where ``lint`` is the lint checking and ``py310`` is Python 3.10. ``core`` are tests that do not require external components, while ``cdh`` and ``hdp`` are two different Hadoop distributions. For most local development it's usually enough to run the lint checking and a single Python version for ``core``, and let GitHub Actions run the whole matrix. For ``cdh`` and ``hdp``, tox will download the Hadoop distribution for you. However, you must have Java installed and the ``JAVA_HOME`` environment variable set. For more details, check out the ``.github/workflows/pythonbuild.yml`` and ``tox.ini`` files. Writing documentation ===================== All documentation for Luigi is written in `reStructuredText/Sphinx markup `_ and lives both in the code as docstrings and in ``.rst`` files.
Pull requests should come with documentation when appropriate. You can verify that your documentation compiles by running: .. code:: bash tox run -e docs After that, you can check how it renders locally in your browser: .. code:: bash firefox doc/_build/html/index.html ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2012-2021 Spotify AB Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.rst ================================================ .. figure:: https://raw.githubusercontent.com/spotify/luigi/master/doc/luigi.png :alt: Luigi Logo :align: center .. image:: https://img.shields.io/endpoint.svg?url=https%3A%2F%2Factions-badge.atrox.dev%2Fspotify%2Fluigi%2Fbadge&label=build&logo=none&%3Fref%3Dmaster&style=flat :target: https://actions-badge.atrox.dev/spotify/luigi/goto?ref=master .. image:: https://img.shields.io/codecov/c/github/spotify/luigi/master.svg?style=flat :target: https://codecov.io/gh/spotify/luigi?branch=master .. image:: https://img.shields.io/pypi/v/luigi.svg?style=flat :target: https://pypi.python.org/pypi/luigi .. image:: https://img.shields.io/pypi/l/luigi.svg?style=flat :target: https://pypi.python.org/pypi/luigi .. 
image:: https://readthedocs.org/projects/luigi/badge/?version=stable :target: https://luigi.readthedocs.io/en/stable/?badge=stable :alt: Documentation Status Luigi is a Python (3.10, 3.11, 3.12, 3.13 tested) package that helps you build complex pipelines of batch jobs. It handles dependency resolution, workflow management, visualization, failure handling, command line integration, and much more. Getting Started --------------- Run ``pip install luigi`` to install the latest stable version from `PyPI `_. `Documentation for the latest release `__ is hosted on readthedocs. Run ``pip install luigi[toml]`` to install Luigi with `TOML-based configs `__ support. For the bleeding edge code, ``pip install git+https://github.com/spotify/luigi.git``. `Bleeding edge documentation `__ is also available. Background ---------- The purpose of Luigi is to address all the plumbing typically associated with long-running batch processes. You want to chain many tasks, automate them, and failures *will* happen. These tasks can be anything, but are typically long-running things like `Hadoop `_ jobs, dumping data to/from databases, running machine learning algorithms, or anything else. There are other software packages that focus on lower-level aspects of data processing, like `Hive `__, `Pig `_, or `Cascading `_. Luigi is not a framework to replace these. Instead, it helps you stitch many tasks together, where each task can be a `Hive query `__, a `Hadoop job in Java `_, a `Spark job in Scala or Python `_, a Python snippet, `dumping a table `_ from a database, or anything else. It's easy to build up long-running pipelines that comprise thousands of tasks and take days or weeks to complete. Luigi takes care of a lot of the workflow management so that you can focus on the tasks themselves and their dependencies. You can build pretty much any task you want, but Luigi also comes with a *toolbox* of several common task templates that you can use.
It includes support for running `Python mapreduce jobs `_ in Hadoop, as well as `Hive `__, and `Pig `__, jobs. It also comes with `file system abstractions for HDFS `_, and local files that ensure all file system operations are atomic. This is important because it means your data pipeline will not crash in a state containing partial data. Visualiser page --------------- The Luigi server comes with a web interface too, so you can search and filter among all your tasks. .. figure:: https://raw.githubusercontent.com/spotify/luigi/master/doc/visualiser_front_page.png :alt: Visualiser page Dependency graph example ------------------------ Just to give you an idea of what Luigi does, this is a screenshot from something we are running in production. Using Luigi's visualiser, we get a nice visual overview of the dependency graph of the workflow. Each node represents a task which has to be run. Green tasks are already completed whereas yellow tasks are yet to be run. Most of these tasks are Hadoop jobs, but there are also some things that run locally and build up data files. .. figure:: https://raw.githubusercontent.com/spotify/luigi/master/doc/user_recs.png :alt: Dependency graph Philosophy ---------- Conceptually, Luigi is similar to `GNU Make `_ where you have certain tasks and these tasks in turn may have dependencies on other tasks. There are also some similarities to `Oozie `_ and `Azkaban `_. One major difference is that Luigi is not built specifically for Hadoop, and it's easy to extend it with other kinds of tasks. Everything in Luigi is in Python. Instead of XML configuration or similar external data files, the dependency graph is specified *within Python*. This makes it easy to build up complex dependency graphs of tasks, where the dependencies can involve date algebra or recursive references to other versions of the same task. However, the workflow can trigger things not in Python, such as running `Pig scripts `_ or `scp'ing files `_. Who uses Luigi?
--------------- We use Luigi internally at `Spotify `_ to run thousands of tasks every day, organized in complex dependency graphs. Most of these tasks are Hadoop jobs. Luigi provides an infrastructure that powers all kinds of stuff including recommendations, toplists, A/B test analysis, external reports, internal dashboards, etc. Since Luigi is open source and without any registration walls, the exact number of Luigi users is unknown. But based on the number of unique contributors, we expect hundreds of enterprises to use it. Some users have written blog posts or held presentations about Luigi: * `Spotify `_ `(presentation, 2014) `__ * `Foursquare `_ `(presentation, 2013) `__ * `Mortar Data (Datadog) `_ `(documentation / tutorial) `__ * `Stripe `_ `(presentation, 2014) `__ * `Buffer `_ `(blog, 2014) `__ * `SeatGeek `_ `(blog, 2015) `__ * `Treasure Data `_ `(blog, 2015) `__ * `Growth Intelligence `_ `(presentation, 2015) `__ * `AdRoll `_ `(blog, 2015) `__ * 17zuoye `(presentation, 2015) `__ * `Custobar `_ `(presentation, 2016) `__ * `Blendle `_ `(presentation) `__ * `TrustYou `_ `(presentation, 2015) `__ * `Groupon `_ / `OrderUp `_ `(alternative implementation) `__ * `Red Hat - Marketing Operations `_ `(blog, 2017) `__ * `GetNinjas `_ `(blog, 2017) `__ * `voyages-sncf.com `_ `(presentation, 2017) `__ * `Open Targets `_ `(blog, 2017) `__ * `Leipzig University Library `_ `(presentation, 2016) `__ / `(project) `__ * `Synetiq `_ `(presentation, 2017) `__ * `Glossier `_ `(blog, 2018) `__ * `Data Revenue `_ `(blog, 2018) `_ * `Uppsala University `_ `(tutorial) `_ / `(presentation, 2015) `_ / `(slides, 2015) `_ / `(poster, 2015) `_ / `(paper, 2016) `_ / `(project) `_ * `GIPHY `_ `(blog, 2019) `__ * `xtream `__ `(blog, 2019) `__ * `CIAN `__ `(presentation, 2019) `__ Some more companies are using Luigi but haven't had a chance yet to write about it: * `Schibsted `_ * `enbrite.ly `_ * `Dow Jones / The Wall Street Journal `_ * `Hotels.com `_ * `Newsela `_ * `Squarespace `_ * 
`OAO `_ * `Grovo `_ * `Weebly `_ * `Deloitte `_ * `Stacktome `_ * `LINX+Neemu+Chaordic `_ * `Foxberry `_ * `Okko `_ * `ISVWorld `_ * `Big Data `_ * `Movio `_ * `Bonnier News `_ * `Starsky Robotics `_ * `BaseTIS `_ * `Hopper `_ * `VOYAGE GROUP/Zucks `_ * `Textpert `_ * `Tracktics `_ * `Whizar `_ * `xtream `__ * `Skyscanner `_ * `Jodel `_ * `Mekar `_ * `M3 `_ * `Assist Digital `_ * `Meltwater `_ * `DevSamurai `_ * `Veridas `_ * `Aidentified `_ We're more than happy to have your company added here. Just send a PR on GitHub. External links -------------- * `Mailing List `_ for discussions and asking questions. (Google Groups) * `Releases `_ (PyPI) * `Source code `_ (GitHub) * `Hubot Integration `_ plugin for Slack, Hipchat, etc. (GitHub) Authors ------- Luigi was built at `Spotify `_, mainly by `Erik Bernhardsson `_ and `Elias Freider `_. `Many other people `_ have contributed since open sourcing in late 2012. `Arash Rouhani `_ was the chief maintainer from 2015 to 2019, and now Spotify's Data Team maintains Luigi. ================================================ FILE: RELEASE-PROCESS.rst ================================================ For maintainers of Luigi who have push access to PyPI. Here's how you upload Luigi to PyPI. #. Make sure uv (https://github.com/astral-sh/uv) is installed, e.g. via ``curl -LsSf https://astral.sh/uv/install.sh | sh``. #. Update the version number in ``luigi/__version__.py``. #. Commit, perhaps simply with a commit message like ``Version x.y.z``. #. Push to GitHub at spotify/luigi (https://github.com/spotify/luigi). #. Clean up previous distributions by executing ``rm -rf dist``. #. Build the distributions by executing ``uv build``. #. Set the PyPI token in the ``UV_PUBLISH_TOKEN`` environment variable: ``export UV_PUBLISH_TOKEN="LUIGI_PYPI_TOKEN_HERE"``. #. Upload to PyPI by executing ``uv publish``. #. Add a tag on GitHub (https://github.com/spotify/luigi/releases), including a handwritten changelog, possibly inspired by previous notes.
Currently, Luigi is not released on any particular schedule and it does not strictly abide by semantic versioning. Whenever possible, bump the major version when you make incompatible API changes, the minor version when you add functionality in a backwards compatible manner, and the patch version when you make backwards compatible bug fixes. ================================================ FILE: SECURITY.md ================================================ # Security Policy ## Reporting a Vulnerability Please report sensitive security issues via Spotify's [bug-bounty program](https://hackerone.com/spotify) by following these [instructions](https://docs.hackerone.com/programs/security-page.html), rather than via GitHub. ================================================ FILE: bin/luigi ================================================ #!/usr/bin/env python import sys import warnings import luigi.cmdline def main(argv): warnings.warn("'bin/luigi' has moved to console script 'luigi'", DeprecationWarning) luigi.cmdline.luigi_run(argv) if __name__ == '__main__': main(sys.argv[1:]) ================================================ FILE: bin/luigid ================================================ #!/usr/bin/env python import sys import warnings import luigi.cmdline def main(argv): warnings.warn("'bin/luigid' has moved to console script 'luigid'", DeprecationWarning) luigi.cmdline.luigid(argv) if __name__ == '__main__': main(sys.argv[1:]) ================================================ FILE: catalog-info.yaml ================================================ apiVersion: backstage.io/v1alpha1 kind: Component metadata: name: luigi spec: type: library owner: dataex ================================================ FILE: codecov.yml ================================================ codecov: require_ci_to_pass: true notify: wait_for_ci: true coverage: precision: 2 round: down range: "50...70" status: project: default: false # disable the default status that measures the entire project core: target: 90% paths: 
- "luigi/*.py" patch: default: target: 50% if_no_uploads: error changes: default: informational: true ignore: - "examples/" - "luigi/tools" # These are tested as actual run commands without coverage # List modules whose tests are not run by CI or are run in subprocesses (e.g. on a cluster). - "luigi/contrib/beam_dataflow.py" - "luigi/contrib/bigquery.py" - "luigi/contrib/bigquery_avro.py" - "luigi/contrib/dataproc.py" - "luigi/contrib/dropbox.py" - "luigi/contrib/ftp.py" - "luigi/contrib/gcs.py" - "luigi/contrib/hadoop.py" - "luigi/contrib/hdfs/" - "luigi/contrib/kubernetes.py" - "luigi/contrib/mrrunner.py" - "luigi/contrib/sparkey.py" - "luigi/contrib/webhdfs.py" # For luigi we do not want Codecov to post PR comments comment: false ================================================ FILE: doc/.gitignore ================================================ _static _build _templates ================================================ FILE: doc/Makefile ================================================ # Makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build PAPER = BUILDDIR = _build # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) endif # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . # the i18n builder cannot share the environment and doctrees with the others I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext help: @echo "Please use \`make ' where is one of" @echo " html to make standalone HTML files" @echo " dirhtml to make HTML files named index.html in directories" @echo " singlehtml to make a single large HTML file" @echo " pickle to make pickle files" @echo " json to make JSON files" @echo " htmlhelp to make HTML files and a HTML help project" @echo " qthelp to make HTML files and a qthelp project" @echo " devhelp to make HTML files and a Devhelp project" @echo " epub to make an epub" @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" @echo " latexpdf to make LaTeX files and run them through pdflatex" @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" @echo " text to make text files" @echo " man to make manual pages" @echo " texinfo to make Texinfo files" @echo " info to make Texinfo files and run them through makeinfo" @echo " gettext to make PO message catalogs" @echo " changes to make an overview of all changed/added/deprecated items" @echo " xml to make Docutils-native XML files" @echo " pseudoxml to make pseudoxml-XML files for display purposes" @echo " linkcheck to check all external links for integrity" @echo " doctest to run all doctests embedded in the documentation (if enabled)" clean: rm -rf $(BUILDDIR)/* html: $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." dirhtml: $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." singlehtml: $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml @echo @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." pickle: $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle @echo @echo "Build finished; now you can process the pickle files." 
json: $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json @echo @echo "Build finished; now you can process the JSON files." htmlhelp: $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp @echo @echo "Build finished; now you can run HTML Help Workshop with the" \ ".hhp project file in $(BUILDDIR)/htmlhelp." qthelp: $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp @echo @echo "Build finished; now you can run "qcollectiongenerator" with the" \ ".qhcp project file in $(BUILDDIR)/qthelp, like this:" @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/Luigi.qhcp" @echo "To view the help file:" @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/Luigi.qhc" devhelp: $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp @echo @echo "Build finished." @echo "To view the help file:" @echo "# mkdir -p $$HOME/.local/share/devhelp/Luigi" @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/Luigi" @echo "# devhelp" epub: $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub @echo @echo "Build finished. The epub file is in $(BUILDDIR)/epub." latex: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." @echo "Run \`make' in that directory to run these through (pdf)latex" \ "(use \`make latexpdf' here to do that automatically)." latexpdf: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through pdflatex..." $(MAKE) -C $(BUILDDIR)/latex all-pdf @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." latexpdfja: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through platex and dvipdfmx..." $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." text: $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text @echo @echo "Build finished. The text files are in $(BUILDDIR)/text." 
man: $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man @echo @echo "Build finished. The manual pages are in $(BUILDDIR)/man." texinfo: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." @echo "Run \`make' in that directory to run these through makeinfo" \ "(use \`make info' here to do that automatically)." info: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo "Running Texinfo files through makeinfo..." make -C $(BUILDDIR)/texinfo info @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." gettext: $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale @echo @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." changes: $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes @echo @echo "The overview file is in $(BUILDDIR)/changes." linkcheck: $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck @echo @echo "Link check complete; look for any errors in the above output " \ "or in $(BUILDDIR)/linkcheck/output.txt." doctest: $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest @echo "Testing of doctests in the sources finished, look at the " \ "results in $(BUILDDIR)/doctest/output.txt." xml: $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml @echo @echo "Build finished. The XML files are in $(BUILDDIR)/xml." pseudoxml: $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml @echo @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." ================================================ FILE: doc/central_scheduler.rst ================================================ Using the Central Scheduler --------------------------- While the ``--local-scheduler`` flag is useful for development purposes, it's not recommended for production usage. 
The centralized scheduler serves two purposes: - Make sure two instances of the same task are not running simultaneously - Provide visualization of everything that's going on. Note that the central scheduler does not execute anything for you or help you with job parallelization. For running tasks periodically, the easiest thing to do is to trigger a Python script from cron or from a continuously running process. There is no central process that automatically triggers jobs. This model may seem limited, but we believe that it makes things far more intuitive and easier to understand. .. figure:: dependency_graph.png :alt: Dependency graph in the visualiser The luigid server ~~~~~~~~~~~~~~~~~ To run the server as a daemon, run: .. code-block:: console $ luigid --background --pidfile --logdir --state-path Note that this requires ``python-daemon``. By default, the server starts on AF_INET and AF_INET6 port ``8082`` (which can be changed with the ``--port`` flag) and listens on all IPs. To change the default behavior of listening on all IPs, pass the ``--address`` flag and the IP address to listen on. To use an AF_UNIX socket, use the ``--unix-socket`` flag. For a full list of configuration options and defaults, see the :ref:`scheduler configuration section `. Note that ``luigid`` uses the same configuration files as the Luigi client (e.g. ``luigi.cfg`` or ``/etc/luigi/luigi.cfg``). .. _TaskHistory: Enabling Task History ~~~~~~~~~~~~~~~~~~~~~ Task History is an experimental feature in which additional information about tasks that have been executed is recorded in a relational database for historical analysis. This information is exposed via the Central Scheduler at ``/history``. To enable the task history, specify ``record_task_history = True`` in the ``[scheduler]`` section of ``luigi.cfg`` and specify ``db_connection`` under ``[task_history]``. The ``db_connection`` string is used to configure the `SQLAlchemy engine `_.
When starting up, ``luigid`` will create all the necessary tables using `create_all `_. Example configuration: .. code:: ini [scheduler] record_task_history = True state_path = /usr/local/var/luigi-state.pickle [task_history] db_connection = sqlite:////usr/local/var/luigi-task-hist.db The task history has the following pages: * ``/history`` a reverse-chronological listing of runs from the past 24 hours. Example screenshot: .. figure:: history.png :alt: Recent history screenshot * ``/history/by_id/{id}`` detailed information about a run, including: parameter values, the host on which it ran, and timing information. Example screenshot: .. figure:: history_by_id.png :alt: By id screenshot * ``/history/by_name/{name}`` a listing of all runs of a task with the given task ``{name}``. Example screenshot: .. figure:: history_by_name.png :alt: By name screenshot * ``/history/by_params/{name}?data=params`` a listing of all runs of the task ``{name}`` restricted to runs whose parameters match the given ``params``. ``params`` is a JSON blob describing the parameters, e.g. ``data={"foo": "bar"}`` looks for a task with ``foo=bar``. * ``/history/by_task_id/{task_id}`` the latest run of a task given the ``{task_id}``. It is different from just ``{id}`` and is a derivative of ``params``. It is available via the ``task_id`` property of a ``luigi.Task`` instance or via `luigi.task.task_id_str `_. This kind of representation is useful for concisely recording URLs in a history tree. Example screenshot: .. figure:: history_by_task_id.png :alt: By task_id screenshot ================================================ FILE: doc/conf.py ================================================ # -*- coding: utf-8 -*- # # Luigi documentation build configuration file, created by # sphinx-quickstart on Sat Feb 8 00:56:43 2014. # # This file is execfile()d with the current directory set to its # containing dir. # # Note that not all possible configuration values are present in this # autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.

import sys
import os
import datetime
from importlib.metadata import Distribution

try:
    import luigi
    import luigi.parameter

    def parameter_repr(self):
        """
        When building documentation, we want Parameter objects to show their
        description in a nice way.

        The result has the shape
        ``[Insignificant ]<ClassName>[ (defaults to <default>)][: <description>]``,
        where each bracketed part is included only when it applies.
        """
        significance = 'Insignificant ' if not self.significant else ''
        class_name = self.__class__.__name__
        # ``_no_value`` marks "no default was given". Compare by identity:
        # ``!=`` would dispatch to the default value's own ``__ne__``, which is
        # not a sentinel check and may not even return a plain bool.
        has_default = self._default is not luigi.parameter._no_value
        default = ' (defaults to {})'.format(self._default) if has_default else ''
        description = (': ' + self.description if self.description else '')
        return significance + class_name + default + description

    luigi.parameter.Parameter.__repr__ = parameter_repr

    def assertIn(needle, haystack):
        """
        We test repr of Parameter objects, since it'll be used for readthedocs.

        Raises AssertionError naming both strings when *needle* is not a
        substring of *haystack*, so a broken docs build says what went wrong.
        """
        assert needle in haystack, '{!r} not found in {!r}'.format(needle, haystack)

    # TODO: find a better place to put this!
    assertIn('IntParameter', repr(luigi.IntParameter()))
    assertIn('defaults to 37', repr(luigi.IntParameter(default=37)))
    assertIn('hi mom', repr(luigi.IntParameter(description='hi mom')))
    assertIn('Insignificant BoolParameter', repr(luigi.BoolParameter(significant=False)))
except ImportError:
    # luigi is not importable (e.g. docs tooling without the package);
    # keep the default Parameter repr.
    pass

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath(os.path.pardir))

# Use both the class docstring and the __init__ docstring for autoclass content.
autoclass_content = 'both'

# -- General configuration ------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
needs_sphinx = '9.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.viewcode', 'sphinx.ext.autosummary', ] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix of source filenames. source_suffix = '.rst' # The encoding of source files. #source_encoding = 'utf-8-sig' # The master toctree document. master_doc = 'index' # General information about the project. project = u'Luigi' authors = u"The Luigi Authors" copyright = u"2011-{}, {}".format(datetime.datetime.now().year, authors) # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # __version__ = Distribution.from_name('luigi').version # assume luigi is already installed # The short X.Y version. version = ".".join(__version__.split(".")[0:2]) # The full version, including alpha/beta/rc tags. release = __version__ # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. #language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: #today = '' # Else, today_fmt is used as the format for a strftime call. #today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. exclude_patterns = ['_build', 'README.rst'] # The reST default role (used for this markup: `text`) to use for all # documents. #default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. #add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). #add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. #show_authors = False # The name of the Pygments (syntax highlighting) style to use. 
pygments_style = 'sphinx' # A list of ignored prefixes for module index sorting. #modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. #keep_warnings = False autodoc_default_options = {'members': True, 'undoc-members': True} autosummary_generate = True autodoc_member_order = 'bysource' # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. html_theme = 'sphinx_rtd_theme' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. #html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. #html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". #html_title = None # A shorter title for the navigation bar. Default is the same as html_title. #html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. html_logo = 'luigi.png' # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. #html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". #html_static_path = ['_static'] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. #html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. 
#html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. #html_use_smartypants = True # Custom sidebar templates, maps document names to template names. #html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. #html_additional_pages = {} # If false, no module index is generated. #html_domain_indices = True # If false, no index is generated. #html_use_index = True # If true, the index is split into individual pages for each letter. #html_split_index = False # If true, links to the reST sources are added to the pages. #html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. #html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. #html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. #html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). #html_file_suffix = None # Output file base name for HTML help builder. htmlhelp_basename = 'Luigidoc' # -- Options for LaTeX output --------------------------------------------- latex_elements = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt', # Additional stuff for the LaTeX preamble. #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ ('index', 'Luigi.tex', u'Luigi Documentation', authors, 'manual'), ] # The name of an image file (relative to this directory) to place at the top of # the title page. 
#latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. #latex_use_parts = False # If true, show page references after internal links. #latex_show_pagerefs = False # If true, show URL addresses after external links. #latex_show_urls = False # Documents to append as an appendix to all manuals. #latex_appendices = [] # If false, no module index is generated. #latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ ('index', 'luigi', u'Luigi Documentation', [authors], 1) ] # If true, show URL addresses after external links. #man_show_urls = False # -- Options for Texinfo output ------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ ('index', 'Luigi', u'Luigi Documentation', authors, 'Luigi', 'One line description of project.', 'Miscellaneous'), ] # Documents to append as an appendix to all manuals. #texinfo_appendices = [] # If false, no module index is generated. #texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. #texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. #texinfo_no_detailmenu = False autodoc_mock_imports = ["mypy"] # sphinx-apidoc --separate generates individual RST files not referenced by any toctree; # suppress the resulting warnings since this is expected behaviour. 
suppress_warnings = ['toc.not_included'] # Some regression introduced # https://github.com/sphinx-doc/sphinx/issues/2330 # https://github.com/spotify/luigi/pull/1555 highlight_language = "python" ================================================ FILE: doc/configuration.rst ================================================ Configuration ============= All configuration can be done by adding configuration files. Supported config parsers: * ``cfg`` (default), based on Python's standard ConfigParser_. Values may refer to environment variables using ``${ENVVAR}`` syntax. * ``toml`` .. _ConfigParser: https://docs.python.org/3/library/configparser.html You can choose right parser via ``LUIGI_CONFIG_PARSER`` environment variable. For example, ``LUIGI_CONFIG_PARSER=toml``. Default (cfg) parser are looked for in: * ``/etc/luigi/client.cfg`` (deprecated) * ``/etc/luigi/luigi.cfg`` * ``client.cfg`` (deprecated) * ``luigi.cfg`` * ``LUIGI_CONFIG_PATH`` environment variable `TOML `_ parser are looked for in: * ``/etc/luigi/luigi.toml`` * ``luigi.toml`` * ``LUIGI_CONFIG_PATH`` environment variable Both config lists increase in priority (from low to high). The order only matters in case of key conflicts (see docs for ConfigParser.read_). These files are meant for both the client and ``luigid``. If you decide to specify your own configuration you should make sure that both the client and ``luigid`` load it properly. .. _ConfigParser.read: https://docs.python.org/3/library/configparser.html#configparser.ConfigParser.read The config file is broken into sections, each controlling a different part of the config. Example cfg config: .. code:: ini [hadoop] version=cdh4 streaming_jar=/usr/lib/hadoop-xyz/hadoop-streaming-xyz-123.jar [core] scheduler_host=luigi-host.mycompany.foo Example toml config: .. 
code:: python [hadoop] version = "cdh4" streaming_jar = "/usr/lib/hadoop-xyz/hadoop-streaming-xyz-123.jar" [core] scheduler_host = "luigi-host.mycompany.foo" Also see `examples/config.toml <https://github.com/spotify/luigi/blob/master/examples/config.toml>`_ for a more complex example. .. _ParamConfigIngestion: Parameters from config Ingestion -------------------------------- All parameters can be overridden from configuration files. For instance, if you have a Task definition: .. code:: python class DailyReport(luigi.contrib.hadoop.JobTask): date = luigi.DateParameter(default=datetime.date.today()) # ... Then you can override the default value for ``DailyReport().date`` by providing it in the configuration: .. code:: ini [DailyReport] date=2012-01-01 .. _ConfigClasses: Configuration classes ********************* Using the :ref:`ParamConfigIngestion` method, we derive the conventional way to do global configuration. Imagine this configuration. .. code:: ini [mysection] option=hello intoption=123 We can create a :py:class:`~luigi.Config` class: .. code:: python import luigi # Config classes should be camel cased class mysection(luigi.Config): option = luigi.Parameter(default='world') intoption = luigi.IntParameter(default=555) mysection().option mysection().intoption Configurable options -------------------- Luigi comes with a lot of configurable options. Below, we describe each section and the parameters available within it. [core] ------ These parameters control core Luigi behavior, such as error e-mails and interactions between the worker and scheduler. autoload_range .. versionadded:: 2.8.11 If false, prevents range tasks from autoloading. They can still be loaded using ``--module luigi.tools.range``. Defaults to true. Setting this to true explicitly disables the deprecation warning. default_scheduler_host Hostname of the machine running the scheduler. Defaults to localhost. default_scheduler_port Port of the remote scheduler api process. Defaults to 8082. default_scheduler_url Full path to remote scheduler.
Defaults to ``http://localhost:8082/``. For TLS support use the URL scheme: ``https``, example: ``https://luigi.example.com:443/`` (Note: you will have to terminate TLS using an HTTP proxy) You can also use this to connect to a local Unix socket using the non-standard URI scheme: ``http+unix`` example: ``http+unix://%2Fvar%2Frun%2Fluigid%2Fluigid.sock/`` hdfs_tmp_dir Base directory in which to store temporary files on hdfs. Defaults to tempfile.gettempdir() history_filename If set, specifies a filename for Luigi to write stuff (currently just job id) to in mapreduce job's output directory. Useful in a configuration where no history is stored in the output directory by Hadoop. log_level The default log level to use when no logging_conf_file is set. Must be a valid name of a `Python log level <https://docs.python.org/3/library/logging.html#logging-levels>`_. Default is ``DEBUG``. logging_conf_file Location of the logging configuration file. no_configure_logging If true, logging is not configured. Defaults to false. parallel_scheduling If true, the scheduler will compute the ``complete()`` functions of tasks in parallel using multiprocessing. This can significantly speed up scheduling, but requires that all tasks can be pickled. Defaults to false. parallel_scheduling_processes The number of processes to use for parallel scheduling. If not specified, the default number of processes will be the total number of CPUs available. rpc_connect_timeout Number of seconds to wait before timing out when making an API call. Defaults to 10.0 rpc_retry_attempts The maximum number of retries to connect to the central scheduler before giving up. Defaults to 3 rpc_retry_wait Number of seconds to wait before the next attempt will be started to connect to the central scheduler between two retry attempts. Defaults to 30 [cors] ------ .. versionadded:: 2.8.0 These parameters control ``/api/`` ``CORS`` behaviour (see: `W3C Cross-Origin Resource Sharing <https://www.w3.org/TR/cors/>`_). enabled Enables CORS support. Defaults to false. allowed_origins A list of allowed origins.
Used only if ``allow_any_origin`` is false. Configure in JSON array format, e.g. ["foo", "bar"]. Defaults to empty. allow_any_origin Accepts requests from any origin. Defaults to false. allow_null_origin Allows the request to set ``null`` value of the ``Origin`` header. Defaults to false. max_age Content of ``Access-Control-Max-Age``. Defaults to 86400 (24 hours). allowed_methods Content of ``Access-Control-Allow-Methods``. Defaults to ``GET, OPTIONS``. allowed_headers Content of ``Access-Control-Allow-Headers``. Defaults to ``Accept, Content-Type, Origin``. exposed_headers Content of ``Access-Control-Expose-Headers``. Defaults to empty string (will NOT be sent as a response header). allow_credentials Indicates that the actual request can include user credentials. Defaults to false. .. _worker-config: [worker] -------- These parameters control Luigi worker behavior. count_uniques If true, workers will only count unique pending jobs when deciding whether to stay alive. So if a worker can't get a job to run and other workers are waiting on all of its pending jobs, the worker will die. ``worker_keep_alive`` must be ``true`` for this to have any effect. Defaults to false. keep_alive If true, workers will stay alive when they run out of jobs to run, as long as they have some pending job waiting to be run. Defaults to false. ping_interval Number of seconds to wait between pinging scheduler to let it know that the worker is still alive. Defaults to 1.0. task_limit .. versionadded:: 1.0.25 Maximum number of tasks to schedule per invocation. Upon exceeding it, the worker will issue a warning and proceed with the workflow obtained thus far. Prevents incidents due to spamming of the scheduler, usually accidental. Default: no limit. task_process_context An optional setting allowing Luigi to import a custom context manager used to wrap the execution of tasks' run methods. Default: no context manager. timeout .. 
versionadded:: 1.0.20 Number of seconds after which to kill a task which has been running for too long. This provides a default value for all tasks, which can be overridden by setting the ``worker_timeout`` property in any task. Default value is 0, meaning no timeout. wait_interval Number of seconds for the worker to wait before asking the scheduler for another job after the scheduler has said that it does not have any available jobs. wait_jitter Duration of jitter to add to the worker wait interval such that the multiple workers do not ask the scheduler for another job at the same time, in seconds. Default: 5.0 max_keep_alive_idle_duration .. versionadded:: 2.8.4 Maximum duration in seconds to keep worker alive while in idle state. Default: 0 (Indefinitely) max_reschedules The maximum number of times that a job can be automatically rescheduled by a worker before it will stop trying. Workers will reschedule a job if it is found to not be done when attempting to run a dependent job. This defaults to 1. retry_external_tasks If true, incomplete external tasks (i.e. tasks where the ``run()`` method is NotImplemented) will be retested for completion while Luigi is running. This means that if external dependencies are satisfied after a workflow has started, any tasks dependent on that resource will be eligible for running. Note: Every time the task remains incomplete, it will count as FAILED, so normal retry logic applies (see: ``retry_count`` and ``retry_delay``). This setting works best with ``worker_keep_alive: true``. If false, external tasks will only be evaluated when Luigi is first invoked. In this case, Luigi will not check whether external dependencies are satisfied while a workflow is in progress, so dependent tasks will remain PENDING until the workflow is reinvoked. Defaults to false for backwards compatibility. no_install_shutdown_handler By default, workers will stop requesting new work and finish running pending tasks after receiving a ``SIGUSR1`` signal. 
This provides a hook for gracefully shutting down workers that are in the process of running (potentially expensive) tasks. If set to true, Luigi will NOT install this shutdown hook on workers. Note this hook does not work on Windows operating systems, or when jobs are launched outside the main execution thread. Defaults to false. send_failure_email Controls whether the worker will send e-mails on task and scheduling failures. If set to false, workers will only send e-mails on framework errors during scheduling and all other e-mail must be handled by the scheduler. Defaults to true. check_unfulfilled_deps If true, the worker checks for completeness of dependencies before running a task. In case unfulfilled dependencies are detected, an exception is raised and the task will not run. This mechanism is useful to detect situations where tasks do not create their outputs properly, or when targets were removed after the dependency tree was built. It is recommended to disable this feature only when the completeness checks are known to be bottlenecks, e.g. when the ``exists()`` calls of the dependencies' outputs are resource-intensive. Defaults to true. force_multiprocessing By default, luigi uses multiprocessing when *more than one* worker process is requested. When set to true, multiprocessing is used independent of the number of workers. Defaults to false. check_complete_on_run By default, luigi tasks are marked as 'done' when they finish running without raising an error. When set to true, tasks will also verify that their outputs exist when they finish running, and will fail immediately if the outputs are missing. Defaults to false. cache_task_completion By default, luigi task processes might check the completion status multiple times per task which is a safe way to avoid potential inconsistencies. For tasks with many dynamic dependencies, yielded in multiple stages, this might become expensive, e.g. in case the per-task completion check entails remote resources. 
When set to true, completion checks are cached so that tasks declared as complete once are not checked again. Defaults to false. [elasticsearch] --------------- These parameters control use of elasticsearch marker_index Defaults to "update_log". marker_doc_type Defaults to "entry". [email] ------- General parameters force_send If true, e-mails are sent in all run configurations (even if stdout is connected to a tty device). Defaults to False. format Type of e-mail to send. Valid values are "plain", "html" and "none". When set to html, tracebacks are wrapped in
  ``<pre>`` tags to get a fixed-width
  font. When set to none, no e-mails will be sent.

  Default value is plain.

method
  Valid values are "smtp", "sendgrid", "ses" and "sns". SES and SNS are
  services of Amazon web services. SendGrid is an email delivery service.
  The default value is "smtp".

  In order to send messages through Amazon SNS or SES set up your AWS
  config files or run Luigi on an EC2 instance with proper instance
  profile.

  In order to use sendgrid, fill in your sendgrid API key in the
  `[sendgrid]`_ section.

  In order to use smtp, fill in the appropriate fields in the `[smtp]`_
  section.

prefix
  Optional prefix to add to the subject line of all e-mails. For
  example, setting this to "[LUIGI]" would change the subject line of an
  e-mail from "Luigi: Framework error" to "[LUIGI] Luigi: Framework
  error"

receiver
  Recipient of all error e-mails. If this is not set, no error e-mails
  are sent when Luigi crashes unless the crashed job has owners set. If
  Luigi is run from the command line, no e-mails will be sent unless
  output is redirected to a file.

  Set it to SNS Topic ARN if you want to receive notifications through
  Amazon SNS. Make sure to set method to sns in this case too.

sender
  User name in from field of error e-mails.
  Default value: ``luigi-client@<hostname>``, where ``<hostname>`` is the
  hostname of the machine running Luigi.

traceback_max_length
  Maximum length for traceback included in error email. Default is 5000.


[batch_email]
----------------

Parameters controlling the contents of batch notifications sent from the
scheduler

email_interval
  Number of minutes between e-mail sends. Making this larger results in
  fewer, bigger e-mails.
  Defaults to 60.

batch_mode
  Controls how tasks are grouped together in the e-mail. Suppose we have
  the following sequence of failures:

  1. TaskA(a=1, b=1)
  2. TaskA(a=1, b=1)
  3. TaskA(a=2, b=1)
  4. TaskA(a=1, b=2)
  5. TaskB(a=1, b=1)

  For any setting of batch_mode, the batch e-mail will record 5 failures
  and mention them in the subject. The difference is in how they will
  be displayed in the body. Here are example bodies with error_messages
  set to 0.

  "all" only groups together failures for the exact same task:

  - TaskA(a=1, b=1) (2 failures)
  - TaskA(a=1, b=2) (1 failure)
  - TaskA(a=2, b=1) (1 failure)
  - TaskB(a=1, b=1) (1 failure)

  "family" groups together failures for tasks of the same family:

  - TaskA (4 failures)
  - TaskB (1 failure)

  "unbatched_params" groups together tasks that look the same after
  removing batched parameters. So if TaskA has a batch_method set for
  parameter a, we get the following:

  - TaskA(b=1) (3 failures)
  - TaskA(b=2) (1 failure)
  - TaskB(a=1, b=1) (1 failure)

  Defaults to "unbatched_params", which is identical to "all" if you are
  not using batched parameters.

error_lines
  Number of lines to include from each error message in the batch
  e-mail. This can be used to keep e-mails shorter while preserving the
  more useful information usually found near the bottom of stack traces.
  This can be set to 0 to include all lines. If you don't wish to see
  error messages, instead set ``error_messages`` to 0.
  Defaults to 20.

error_messages
  Number of messages to preserve for each task group. As most tasks that
  fail repeatedly do so for similar reasons each time, it's not usually
  necessary to keep every message. This controls how many messages are
  kept for each task or task group. The most recent error messages are
  kept. Set to 0 to not include error messages in the e-mails.
  Defaults to 1.

group_by_error_messages
  Quite often, a system or cluster failure will cause many disparate
  task types to fail for the same reason. This can cause a lot of noise
  in the batch e-mails. This cuts down on the noise by listing items
  with identical error messages together. Error messages are compared
  after limiting by ``error_lines``.
  Defaults to true.


[hadoop]
--------

Parameters controlling basic hadoop tasks

command
  Name of command for running hadoop from the command line. Defaults to
  "hadoop"

python_executable
  Name of command for running python from the command line. Defaults to
  "python"

scheduler
  Type of scheduler to use when scheduling hadoop jobs. Can be "fair" or
  "capacity". Defaults to "fair".

streaming_jar
  Path to your streaming jar. Must be specified to run streaming jobs.

version
  Version of hadoop used in your cluster. Can be "cdh3", "cdh4", or
  "apache1". Defaults to "cdh4".


[hdfs]
------

Parameters controlling the use of snakebite to speed up hdfs queries.

client
  Client to use for most hadoop commands. Options are "snakebite",
  "snakebite_with_hadoopcli_fallback", "webhdfs" and "hadoopcli". Snakebite is
  much faster, so use of it is encouraged. webhdfs is fast and works with
  Python 3 as well, but has not been used that much in the wild.
  Both snakebite and webhdfs require you to install them separately on
  the machine. Defaults to "hadoopcli".

client_version
  Optionally specifies hadoop client version for snakebite.

effective_user
  Optionally specifies the effective user for snakebite.

namenode_host
  The hostname of the namenode. Needed for snakebite if
  snakebite_autoconfig is not set.

namenode_port
  The port used by snakebite on the namenode. Needed for snakebite if
  snakebite_autoconfig is not set.

snakebite_autoconfig
  If true, attempts to automatically detect the host and port of the
  namenode for snakebite queries. Defaults to false.

tmp_dir
  Path to where Luigi will put temporary files on hdfs.


[hive]
------

Parameters controlling hive tasks

command
  Name of the command used to run hive on the command line. Defaults to
  "hive".

hiverc_location
  Optional path to hive rc file.

metastore_host
  Hostname for metastore.

metastore_port
  Port for hive to connect to metastore host.

release
  If set to "apache", uses a hive client that better handles apache
  hive output. All other values use the standard client. Defaults to
  "cdh4".


[kubernetes]
------------

Parameters controlling Kubernetes Job Tasks

auth_method
  Authorization method to access the cluster.
  Options are "kubeconfig_" or "service-account_"

kubeconfig_path
  Path to kubeconfig file, for cluster authentication.
  It defaults to ``~/.kube/config``, which is the default location when
  using minikube_.
  When auth_method is "service-account" this property is ignored.

max_retrials
  Maximum number of retries in case of job failure.

.. _service-account: http://kubernetes.io/docs/user-guide/service-accounts
.. _kubeconfig: http://kubernetes.io/docs/user-guide/kubeconfig-file
.. _minikube: http://kubernetes.io/docs/getting-started-guides/minikube


[mysql]
-------

Parameters controlling use of MySQL targets

marker_table
  Table in which to store status of table updates. This table will be
  created if it doesn't already exist. Defaults to "table_updates".


[postgres]
----------

Parameters controlling the use of Postgres targets

local_tmp_dir
  Directory in which to temporarily store data before writing to
  postgres. Uses system default if not specified.

marker_table
  Table in which to store status of table updates. This table will be
  created if it doesn't already exist. Defaults to "table_updates".


[prometheus]
------------

use_task_family_in_labels
  Whether the task family should be used as a Prometheus bucket label.
  Default value is true.

task_parameters_to_use_in_labels
  List of task arguments' names used as additional prometheus bucket labels.
  Passed in a form of a json list.


[redshift]
----------

Parameters controlling the use of Redshift targets

marker_table
  Table in which to store status of table updates. This table will be
  created if it doesn't already exist. Defaults to "table_updates".

.. _resources-config:

[resources]
-----------

This section can contain arbitrary keys. Each of these specifies the
amount of a global resource that the scheduler can allow workers to use.
The scheduler will prevent running jobs with resources specified from
exceeding the counts in this section. Unspecified resources are assumed
to have limit 1. Example resources section for a configuration with 2
hive resources and 1 mysql resource:

.. code:: ini

  [resources]
  hive=2
  mysql=1

Note that it was not necessary to specify the 1 for mysql here, but it
is good practice to do so when you have a fixed set of resources.

.. _retcode-config:

[retcode]
---------

Configure return codes for the Luigi binary. In the case of multiple return
codes that could apply, for example a failing task and missing data, the
*numerically greatest* return code is returned.

We recommend that you copy this set of exit codes to your ``luigi.cfg`` file:

.. code:: ini

  [retcode]
  # The following return codes are the recommended exit codes for Luigi
  # They are in increasing level of severity (for most applications)
  already_running=10
  missing_data=20
  not_run=25
  task_failed=30
  scheduling_error=35
  unhandled_exception=40

already_running
  This can happen in two different cases. Either the local lock file was
  already taken when the invocation started up, or the central scheduler has
  reported that some tasks could not be run because other workers are already
  running them.
missing_data
  For when an :py:class:`~luigi.task.ExternalTask` is not complete, and this
  caused the worker to give up.  As an alternative to fiddling with this, see
  the [worker] keep_alive option.
not_run
  For when a task is not granted run permission by the scheduler. Typically
  because of lack of resources, because the task has been already run by
  another worker or because the attempted task is in DISABLED state.
  Connectivity issues with the central scheduler might also cause this.
  This does not include the cases for which a run is not allowed due to missing
  dependencies (missing_data) or due to the fact that another worker is currently
  running the task (already_running).
task_failed
  For signaling that there were tasks that were last known to have failed,
  typically because some exception has been raised.
scheduling_error
  For when a task's ``complete()`` or ``requires()`` method fails with an
  exception, or when the limit number of tasks is reached.
unhandled_exception
  For internal Luigi errors.  Defaults to 4, since this type of error
  probably will not recover over time.

If you customize return codes, prefer to set them in range 128 to 255 to avoid
conflicts. Return codes in range 0 to 127 are reserved for possible future use
by Luigi contributors.

[scalding]
----------

Parameters controlling running of scalding jobs

scala_home
  Home directory for scala on your machine. Defaults to either
  SCALA_HOME or /usr/share/scala if SCALA_HOME is unset.

scalding_home
  Home directory for scalding on your machine. Defaults to either
  SCALDING_HOME or /usr/share/scalding if SCALDING_HOME is unset.

scalding_provided
  Provided directory for scalding on your machine. Defaults to either
  SCALDING_HOME/provided or /usr/share/scalding/provided

scalding_libjars
  Libjars directory for scalding on your machine. Defaults to either
  SCALDING_HOME/libjars or /usr/share/scalding/libjars


.. _scheduler-config:

[scheduler]
-----------

Parameters controlling scheduler behavior

batch_emails
  Whether to send batch e-mails for failures and disables rather than
  sending immediate disable e-mails and just relying on workers to send
  immediate failure e-mails.
  Defaults to false.

disable_hard_timeout
  Hard time limit after which tasks will be disabled by the server if
  they fail again, in seconds. It will disable the task if it fails
  **again** after this amount of time. E.g. if this was set to 600
  (i.e. 10 minutes), and the task first failed at 10:00am, the task would
  be disabled if it failed again any time after 10:10am. Note: This setting
  does not consider the values of the ``retry_count`` or
  ``disable_window`` settings.

retry_count
  Number of times a task can fail within ``disable_window`` before
  the scheduler will automatically disable it. If not set, the scheduler
  will not automatically disable jobs.

disable_persist
  Number of seconds for which an automatic scheduler disable lasts.
  Defaults to 86400 (1 day).

disable_window
  Number of seconds during which ``retry_count`` failures must
  occur in order for an automatic disable by the scheduler. The
  scheduler forgets about disables that have occurred longer ago than
  this amount of time. Defaults to 3600 (1 hour).

max_shown_tasks
  .. versionadded:: 1.0.20

  The maximum number of tasks returned in a task_list api call. This
  will restrict the number of tasks shown in task lists in the
  visualiser. Small values can alleviate frozen browsers when there are
  too many done tasks. This defaults to 100000 (one hundred thousand).

max_graph_nodes
  .. versionadded:: 2.0.0

  The maximum number of nodes returned by a dep_graph or
  inverse_dep_graph api call. Small values can greatly speed up graph
  display in the visualiser by limiting the number of nodes shown. Some
  of the nodes that are not sent to the visualiser will still show up as
  dependencies of nodes that were sent. These nodes are given TRUNCATED
  status.

record_task_history
  If true, stores task history in a database. Defaults to false.

remove_delay
  Number of seconds to wait before removing a task that has no
  stakeholders. Defaults to 600 (10 minutes).

retry_delay
  Number of seconds to wait after a task failure to mark it pending
  again. Defaults to 900 (15 minutes).

state_path
  Path in which to store the Luigi scheduler's state. When the scheduler
  is shut down, its state is stored in this path. The scheduler must be
  shut down cleanly for this to work, usually with a kill command. If
  the kill command includes the -9 flag, the scheduler will not be able
  to save its state. When the scheduler is started, it will load the
  state from this path if it exists. This will restore all scheduled
  jobs and other state from when the scheduler last shut down.

  Sometimes this path must be deleted when restarting the scheduler
  after upgrading Luigi, as old state files can become incompatible
  with the new scheduler. When this happens, all workers should be
  restarted after the scheduler both to become compatible with the
  updated code and to reschedule the jobs that the scheduler has now
  forgotten about.

  This defaults to /var/lib/luigi-server/state.pickle

worker_disconnect_delay
  Number of seconds to wait after a worker has stopped pinging the
  scheduler before removing it and marking all of its running tasks as
  failed. Defaults to 60.

pause_enabled
  If false, disables pause/unpause operations and hides the pause toggle from
  the visualiser.

send_messages
  When true, the scheduler is allowed to send messages to running tasks and
  the central scheduler provides a simple prompt per task to send messages.
  Defaults to true.

metrics_collector
  Optional setting allowing Luigi to use a contrib module to report metrics
  about the pipeline to a third party. By default this uses the default metric
  collector that acts as a shell and does nothing. The currently available
  options are "datadog", "prometheus" and "custom". If it's custom the
  'metrics_custom_import' needs to be set.

metrics_custom_import
  Optional setting allowing Luigi to import a custom subclass of MetricsCollector
  at runtime. The string should be formatted like "module.sub_module.ClassName".


[sendgrid]
----------

These parameters control sending error e-mails through SendGrid.

apikey
  API key of the SendGrid account.


[smtp]
------

These parameters control the smtp server setup.

host
  Hostname for sending mail through smtp. Defaults to localhost.

local_hostname
  If specified, overrides the FQDN of localhost in the HELO/EHLO
  command.

no_tls
  If true, connects to smtp without TLS. Defaults to false.

password
  Password to log in to your smtp server. Must be specified for
  username to have an effect.

port
  Port number for smtp on ``host``. Defaults to 0.

ssl
  If true, connects to smtp through SSL. Defaults to false.

timeout
  Sets the number of seconds after which smtp attempts should time out.
  Defaults to 10.

username
  Username to log in to your smtp server, if necessary.


[spark]
-------

Parameters controlling the default execution of :py:class:`~luigi.contrib.spark.SparkSubmitTask` and :py:class:`~luigi.contrib.spark.PySparkTask`:

.. deprecated:: 1.1.1
   :py:class:`~luigi.contrib.spark.SparkJob`, :py:class:`~luigi.contrib.spark.Spark1xJob` and :py:class:`~luigi.contrib.spark.PySpark1xJob`
   are deprecated. Please use :py:class:`~luigi.contrib.spark.SparkSubmitTask` or :py:class:`~luigi.contrib.spark.PySparkTask`.

spark_submit
  Command to run in order to submit spark jobs. Default: ``"spark-submit"``

master
  Master url to use for ``spark_submit``. Example: local[*], spark://masterhost:7077. Default: Spark default (Prior to 1.1.1: yarn-client)

deploy_mode
    Whether to launch the driver programs locally ("client") or on one of the worker machines inside the cluster ("cluster"). Default: Spark default

jars
    Comma-separated list of local jars to include on the driver and executor classpaths. Default: Spark default

packages
    Comma-separated list of packages to link to on the driver and executors

py_files
    Comma-separated list of .zip, .egg, or .py files to place on the ``PYTHONPATH`` for Python apps. Default: Spark default

files
    Comma-separated list of files to be placed in the working directory of each executor. Default: Spark default

conf
    Arbitrary Spark configuration property in the form Prop=Value|Prop2=Value2. Default: Spark default

properties_file
    Path to a file from which to load extra properties. Default: Spark default

driver_memory
    Memory for driver (e.g. 1000M, 2G). Default: Spark default

driver_java_options
    Extra Java options to pass to the driver. Default: Spark default

driver_library_path
    Extra library path entries to pass to the driver. Default: Spark default

driver_class_path
    Extra class path entries to pass to the driver. Default: Spark default

executor_memory
    Memory per executor (e.g. 1000M, 2G). Default: Spark default

*Configuration for Spark submit jobs on Spark standalone with cluster deploy mode only:*

driver_cores
    Cores for driver. Default: Spark default

supervise
    If given, restarts the driver on failure. Default: Spark default

*Configuration for Spark submit jobs on Spark standalone and Mesos only:*

total_executor_cores
    Total cores for all executors. Default: Spark default

*Configuration for Spark submit jobs on YARN only:*

executor_cores
    Number of cores per executor. Default: Spark default

queue
    The YARN queue to submit to. Default: Spark default

num_executors
    Number of executors to launch. Default: Spark default

archives
    Comma-separated list of archives to be extracted into the working directory of each executor. Default: Spark default

hadoop_conf_dir
  Location of the hadoop conf dir. Sets HADOOP_CONF_DIR environment variable
  when running spark. Example: /etc/hadoop/conf

*Extra configuration for PySparkTask jobs:*

py_packages
    Comma-separated list of local packages (in your python path) to be distributed to the cluster.

*Parameters controlling the execution of SparkJob jobs (deprecated):*


[task_history]
--------------

Parameters controlling storage of task history in a database

db_connection
  Connection string for connecting to the task history db using
  sqlalchemy.


[execution_summary]
-------------------

Parameters controlling execution summary of a worker

summary_length
  Maximum number of tasks to show in an execution summary.  If the value is 0,
  then all tasks will be displayed.  Default value is 5.


[webhdfs]
---------

port
  The port to use for webhdfs. Note that this is probably a different port
  from the normal namenode port.

user
  Perform file system operations as the specified user instead of $USER.  Since
  this parameter is not honored by any of the other hdfs clients, you should
  think twice before setting this parameter.

client_type
  The type of client to use. Default is the "insecure" client that requires no
  authentication. The other option is the "kerberos" client that uses kerberos
  authentication.

[datadog]
---------

api_key
  The api key found in the account settings of Datadog under the API
  sections.
app_key
  The application key found in the account settings of Datadog under the API
  sections.
default_tags
  Optional setting that adds tags to all the metrics and events sent to
  Datadog. Default value is "application:luigi".
environment
  Allows you to set the environment name in order to differentiate between production,
  staging or development metrics within Datadog. Default value is "development".
statsd_host
  The host that has the statsd instance to allow Datadog to send statsd metric. Default value is "localhost".
statsd_port
  The port on the host that allows connection to the statsd host. Default value is 8125.
metric_namespace
  Optional prefix to add to the beginning of every metric sent to Datadog.
  Default value is "luigi".


Per Task Retry-Policy
---------------------

Luigi also supports defining ``retry_policy`` per task.

.. code-block:: python

    class GenerateWordsFromHdfs(luigi.Task):

        retry_count = 2

        ...

    class GenerateWordsFromRDBM(luigi.Task):

        retry_count = 5

        ...

    class CountLetters(luigi.Task):

        def requires(self):
            return [GenerateWordsFromHdfs()]

        def run(self):
            yield GenerateWordsFromRDBM()

        ...

If a retry-policy field is not defined for a task, its value will be the **default** value defined in the luigi config file.

To make luigi stick to the given retry-policy, be sure to run the luigi worker with the ``keep_alive`` config. Please check the ``keep_alive`` config in the :ref:`worker-config` section.

Retry-Policy Fields
-------------------

The fields below make up the retry-policy, and each of them can be defined per task.

* ``retry_count``
* ``disable_hard_timeout``
* ``disable_window``


================================================
FILE: doc/design_and_limitations.rst
================================================
Design and limitations
----------------------

Luigi is the successor to a couple of attempts that we weren't fully happy with.
We learned a lot from our mistakes and some design decisions include:

-  Straightforward command-line integration.
-  As little boilerplate as possible.
-  Focus on job scheduling and dependency resolution, not a particular platform.
   In particular, this means no limitation to Hadoop.
   Though Hadoop/HDFS support is built-in and is easy to use,
   this is just one of many types of things you can run.
-  A file system abstraction where code doesn't have to care about where files are located.
-  Atomic file system operations through this abstraction.
   If a task crashes it won't lead to a broken state.
-  The dependencies are decentralized.
   No big config file in XML.
   Each task just specifies which inputs it needs and cross-module dependencies are trivial.
-  A web server that renders the dependency graph and does locking, etc., for free.
-  Trivial to extend with new file systems, file formats, and job types.
   You can easily write jobs that insert a Tokyo Cabinet into Cassandra.
   Adding support for new systems is generally not very hard.
   (Feel free to send us a patch when you're done!)
-  Date algebra included.
-  Lots of unit tests of the most basic stuff.

It wouldn't be fair not to mention some limitations with the current design:

-  Its focus is on batch processing so
   it's probably less useful for near real-time pipelines or continuously running processes.
-  The assumption is that each task is a sizable chunk of work.
   While you can probably schedule a few thousand jobs,
   it's not meant to scale beyond tens of thousands.
-  Luigi does not support distribution of execution.
   When you have workers running thousands of jobs daily, this starts to matter,
   because the worker nodes get overloaded.
   There are some ways to mitigate this (trigger from many nodes, use resources),
   but none of them are ideal.
-  Luigi does not come with built-in triggering, and you still need to rely on something like
   crontab to trigger workflows periodically.

Also, it should be mentioned that Luigi is named after the world's second most famous plumber.


================================================
FILE: doc/example_top_artists.rst
================================================
Example – Top Artists
---------------------

This is a very simplified case of something we do at Spotify a lot.
All user actions are logged to Google Cloud Storage (previously HDFS) where
we run a bunch of processing jobs to transform the data. The processing code itself is implemented
in a scalable data processing framework, such as Scio, Scalding, or Spark, but the jobs
are orchestrated with Luigi.
At some point we might end up with
a smaller data set that we can bulk ingest into Cassandra, Postgres, or
other storage suitable for serving or exploration.

For the purpose of this exercise, we want to aggregate all streams,
find the top 10 artists and then put the results into Postgres.

This example is also available in
`examples/top_artists.py <https://github.com/spotify/luigi/blob/master/examples/top_artists.py>`_.

Step 1 - Aggregate Artist Streams
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: python

    class AggregateArtists(luigi.Task):
        date_interval = luigi.DateIntervalParameter()

        def output(self):
            return luigi.LocalTarget("data/artist_streams_%s.tsv" % self.date_interval)

        def requires(self):
            return [Streams(date) for date in self.date_interval]

        def run(self):
            artist_count = defaultdict(int)

            for input in self.input():
                with input.open('r') as in_file:
                    for line in in_file:
                        timestamp, artist, track = line.strip().split()
                        artist_count[artist] += 1

            with self.output().open('w') as out_file:
                for artist, count in artist_count.items():
                    print(artist, count, file=out_file)

Note that this is just a portion of the file ``examples/top_artists.py``.
In particular, ``Streams`` is defined as a :class:`~luigi.task.Task`,
acting as a dependency for ``AggregateArtists``.
In addition, ``luigi.run()`` is called if the script is executed directly,
allowing it to be run from the command line.

There are several pieces of this snippet that deserve more explanation.

-  Any :class:`~luigi.task.Task` may be customized by instantiating one
   or more :class:`~luigi.parameter.Parameter` objects on the class level.
-  The :func:`~luigi.task.Task.output` method tells Luigi where the result
   of running the task will end up. The path can be some function of the
   parameters.
-  The :func:`~luigi.task.Task.requires` tasks specifies other tasks that
   we need to perform this task. In this case it's an external dump named
   *Streams* which takes the date as the argument.
-  For plain Tasks, the :func:`~luigi.task.Task.run` method implements the
   task. This could be anything, including calling subprocesses, performing
   long running number crunching, etc. For some subclasses of
   :class:`~luigi.task.Task` you don't have to implement the ``run``
   method. For instance, for the :class:`~luigi.contrib.hadoop.JobTask`
   subclass you implement a *mapper* and *reducer* instead.
-  :class:`~luigi.LocalTarget` is a built in class that makes it
   easy to read/write from/to the local filesystem. It also makes all file operations
   atomic, which is nice in case your script crashes for any reason.

Running this Locally
~~~~~~~~~~~~~~~~~~~~

Try running this using, for example:

.. code-block:: console

    $ cd examples
    $ luigi --module top_artists AggregateArtists --local-scheduler --date-interval 2012-06

Note that *top_artists* needs to be in your PYTHONPATH, or else this can produce an error (*ImportError: No module named top_artists*). Add the current working directory to the PYTHONPATH of the command with:

.. code-block:: console

    $ PYTHONPATH='.' luigi --module top_artists AggregateArtists --local-scheduler --date-interval 2012-06

You can also try to view the manual using ``--help`` which will give you an
overview of the options.

Running the command again will do nothing because the output file is
already created.
In that sense, any task in Luigi is *idempotent*
because running it many times gives the same outcome as running it once.
Note that unlike Makefile, the output will not be recreated when any of
the input files is modified.
You need to delete the output file
manually.

The ``--local-scheduler`` flag tells Luigi not to connect to a scheduler
server. This is not recommended for anything other than testing
things.

Step 1b - Aggregate artists with Spark
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

While Luigi can process data inline, it is normally used to orchestrate external programs that
perform the actual processing. In this example, we will demonstrate how top artists instead can be
read from HDFS and calculated with Spark, orchestrated by Luigi.

.. code:: python

    class AggregateArtistsSpark(luigi.contrib.spark.SparkSubmitTask):
        date_interval = luigi.DateIntervalParameter()

        app = 'top_artists_spark.py'
        master = 'local[*]'

        def output(self):
            return luigi.contrib.hdfs.HdfsTarget("data/artist_streams_%s.tsv" % self.date_interval)

        def requires(self):
            return [StreamsHdfs(date) for date in self.date_interval]

        def app_options(self):
            # :func:`~luigi.task.Task.input` returns the targets produced by the tasks in
            # `~luigi.task.Task.requires`.
            return [','.join([p.path for p in self.input()]),
                    self.output().path]


:class:`luigi.contrib.spark.SparkSubmitTask` doesn't require you to implement a
:func:`~luigi.task.Task.run` method. Instead, you specify the command line parameters to send
to ``spark-submit``, as well as any other configuration specific to Spark.

Python code for the Spark job is found below.

.. code:: python

    import operator
    import sys
    from pyspark.sql import SparkSession


    def main(argv):
        input_paths = argv[1].split(',')
        output_path = argv[2]

        spark = SparkSession.builder.getOrCreate()

        streams = spark.read.option('sep', '\t').csv(input_paths[0])
        for stream_path in input_paths[1:]:
            streams = streams.union(spark.read.option('sep', '\t').csv(stream_path))

        # The second field is the artist
        counts = streams.rdd \
            .map(lambda row: (row[1], 1)) \
            .reduceByKey(operator.add) \
            .toDF()

        counts.write.option('sep', '\t').csv(output_path)


    if __name__ == '__main__':
        sys.exit(main(sys.argv))


In a typical deployment scenario, the Luigi orchestration definition above as well as the
Pyspark processing code would be packaged into a deployment package, such as a container image. The
processing code does not have to be implemented in Python, any program can be packaged in the
image and run from Luigi.


Step 2 – Find the Top Artists
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

At this point, we've counted the number of streams for each artist,
for the full time period.
We are left with a large file that contains
mappings of artist -> count data, and we want to find the top 10 artists.
Since we only have a few hundred thousand artists, and
calculating the top artists is nontrivial to parallelize,
we choose to do this not as a Hadoop job, but just as a plain old for-loop in Python.

.. code:: python

    class Top10Artists(luigi.Task):
        date_interval = luigi.DateIntervalParameter()
        use_hadoop = luigi.BoolParameter()

        def requires(self):
            if self.use_hadoop:
                return AggregateArtistsSpark(self.date_interval)
            else:
                return AggregateArtists(self.date_interval)

        def output(self):
            return luigi.LocalTarget("data/top_artists_%s.tsv" % self.date_interval)

        def run(self):
            top_10 = nlargest(10, self._input_iterator())
            with self.output().open('w') as out_file:
                for streams, artist in top_10:
                    print(self.date_interval.date_a, self.date_interval.date_b, artist, streams, file=out_file)

        def _input_iterator(self):
            with self.input().open('r') as in_file:
                for line in in_file:
                    artist, streams = line.strip().split()
                    yield int(streams), artist

The most interesting thing here is that this task (*Top10Artists*)
defines a dependency on the previous task (*AggregateArtists*).
This means that if the output of *AggregateArtists* does not exist,
the task will run before *Top10Artists*.

.. code-block:: console

    $ luigi --module examples.top_artists Top10Artists --local-scheduler --date-interval 2012-07

This will run both tasks.

Step 3 - Insert into Postgres
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This mainly serves as an example of a specific *Task* subclass that
doesn't require any code to be written.
It's also an example of how you can define task templates that
you can reuse for a lot of different tasks.

.. code:: python

    class ArtistToplistToDatabase(luigi.contrib.postgres.CopyToTable):
        date_interval = luigi.DateIntervalParameter()
        use_hadoop = luigi.BoolParameter()

        host = "localhost"
        database = "toplists"
        user = "luigi"
        password = "abc123"  # ;)
        table = "top10"

        columns = [("date_from", "DATE"),
                   ("date_to", "DATE"),
                   ("artist", "TEXT"),
                   ("streams", "INT")]

        def requires(self):
            return Top10Artists(self.date_interval, self.use_hadoop)

Just like previously, this defines a recursive dependency on the
previous task. If you try to build the task, that will also trigger
building all its upstream dependencies.

Using the Central Planner
~~~~~~~~~~~~~~~~~~~~~~~~~

The ``--local-scheduler`` flag tells Luigi not to connect to a central scheduler.
This is recommended in order to get started and/or for development purposes.
At the point where you start putting things in production
we strongly recommend running the central scheduler server.
In addition to providing locking
so that the same task is not run by multiple processes at the same time,
this server also provides a pretty nice visualization of your current workflow.

If you drop the ``--local-scheduler`` flag,
your script will try to connect to the central planner,
by default at localhost port 8082.
If you run

.. code-block:: console

    $ luigid

in the background and then run your task without the ``--local-scheduler`` flag,
then your script will now schedule through a centralized server.
You need `Tornado <https://www.tornadoweb.org/>`__ for this to work.

Launching http://localhost:8082 should show something like this:

.. figure:: web_server.png
   :alt: Web server screenshot

   Web server screenshot

Looking at the dependency graph
for any of the tasks yields something like this:

.. figure:: aggregate_artists.png
   :alt: Aggregate artists screenshot

   Aggregate artists screenshot

In production, you'll want to run the centralized scheduler.
See: :doc:`central_scheduler` for more information.


================================================
FILE: doc/execution_model.rst
================================================
Execution Model
---------------

Luigi has a quite simple model for execution and triggering.

Workers and task execution
~~~~~~~~~~~~~~~~~~~~~~~~~~

The most important aspect is that *no execution is transferred*.
When you run a Luigi workflow,
the worker schedules all tasks, and
also executes the tasks within the process.

    .. figure:: execution_model.png
       :alt: Execution model

The benefit of this scheme is that
it's super easy to debug since all execution takes place in the process.
It also makes deployment a non-event.
During development,
you typically run the Luigi workflow from the command line,
whereas when you deploy it,
you can trigger it using crontab or any other scheduler.

The downside is that Luigi doesn't give you scalability for free.
In practice this is not a problem until you start running thousands of tasks.

Isn't the point of Luigi to automate and schedule these workflows?
To some extent.
Luigi helps you *encode the dependencies* of tasks and build up chains.
Furthermore, Luigi's scheduler makes sure that there's a centralized view of the dependency graph and
that the same job will not be executed by multiple workers simultaneously.

Scheduler
~~~~~~~~~

A client only starts the ``run()`` method of a task when the single-threaded
central scheduler has permitted it. Since the number of tasks is usually very
small (in comparison with the petabytes of data one task is processing), we
can afford the convenience of a simple centralized server.

.. figure:: https://tarrasch.github.io/luigid-basics-jun-2015/img/50.gif
   :alt: Scheduling gif

The gif is from `this presentation
<https://tarrasch.github.io/luigid-basics-jun-2015/>`__, which is about the
client and server interaction.

Triggering tasks
~~~~~~~~~~~~~~~~

Luigi does not include its own triggering, so you have to rely on an external scheduler
such as crontab to actually trigger the workflows.

In practice, it's not a big hurdle because Luigi avoids all the mess typically caused by it.
Scheduling a complex workflow is fairly trivial using e.g. crontab.

In the future, Luigi might implement its own triggering.
The dependency on crontab (or any external triggering mechanism) is a bit awkward and it would be nice to avoid.

Trigger example
^^^^^^^^^^^^^^^

For instance, if you have an external data dump that arrives every day and your workflow depends on it,
you write a workflow that depends on this data dump.
Crontab can then trigger this workflow *every minute* to check if the data has arrived.
If it has, it will run the full dependency graph.

.. code:: python

    # my_tasks.py

    class DataDump(luigi.ExternalTask):
        date = luigi.DateParameter()
        def output(self): return luigi.contrib.hdfs.HdfsTarget(self.date.strftime('/var/log/dump/%Y-%m-%d.txt'))

    class AggregationTask(luigi.Task):
        date = luigi.DateParameter()
        window = luigi.IntParameter()
        def requires(self): return [DataDump(self.date - datetime.timedelta(i)) for i in range(self.window)]
        def run(self): run_some_cool_stuff(self.input())
        def output(self): return luigi.contrib.hdfs.HdfsTarget('/aggregated-%s-%d' % (self.date, self.window))

    class RunAll(luigi.Task):
        ''' Dummy task that triggers execution of other tasks'''
        def requires(self):
            for window in [3, 7, 14]:
                for d in range(10):  # guarantee that aggregations were run for the past 10 days
                    yield AggregationTask(datetime.date.today() - datetime.timedelta(d), window)

In your cronline you would then have something like

.. code:: console

    * * * * * my-user luigi RunAll --module my_tasks


You can trigger this as much as you want from crontab, and
even across multiple machines, because
the central scheduler will make sure at most one of each ``AggregationTask`` task is run simultaneously.
Note that this might actually mean multiple tasks can be run because
there are instances with different parameters, and
this can give you some form of parallelization
(e.g. ``AggregationTask(2013-01-09)`` might run in parallel with ``AggregationTask(2013-01-08)``).

Of course,
some Task types (e.g. ``HadoopJobTask``) can transfer execution to other places, but
this is up to each Task to define.


================================================
FILE: doc/index.rst
================================================
.. Luigi documentation master file, created by
   sphinx-quickstart on Sat Feb  8 00:56:43 2014.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

.. include:: ../README.rst

Table of Contents
-----------------

.. toctree::
   :maxdepth: 2

   example_top_artists.rst
   workflows.rst
   tasks.rst
   parameters.rst
   running_luigi.rst
   central_scheduler.rst
   execution_model.rst
   luigi_patterns.rst
   configuration.rst
   logging.rst
   design_and_limitations.rst
   mypy.rst

API Reference
-------------

.. autosummary::
   :toctree: api
   :recursive:

   luigi


Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`


================================================
FILE: doc/logging.rst
================================================
Configure logging
-----------------


Config options:
~~~~~~~~~~~~~~~

Some config options for the [core] section:

log_level
    The default log level to use when no logging_conf_file is set. Must be
    a valid name of a `Python log level
    <https://docs.python.org/3/library/logging.html#levels>`_.
    Default is ``DEBUG``.
logging_conf_file
      Location of the logging configuration file.
no_configure_logging
    If true, logging is not configured. Defaults to false.


Config section
~~~~~~~~~~~~~~

If you use TOML for your configuration file, you can configure logging
via the ``logging`` section in this file. See `example
<https://github.com/spotify/luigi/blob/master/examples/config.toml>`_
for more details.

Luigid CLI options:
~~~~~~~~~~~~~~~~~~~

``--background``
    Run the daemon in background mode. Disables logging setup
    and sets the log level to INFO for the root logger.
``--logdir``
    Set up logging at INFO level with output to the ``$logdir/luigi-server.log`` file.


Worker CLI options:
~~~~~~~~~~~~~~~~~~~

``--logging-conf-file``
    Configuration file for logging.
``--log-level``
    Default log level.
    Available values: NOTSET, DEBUG, INFO, WARNING, ERROR, CRITICAL.
    Default DEBUG. See the `Python documentation
    <https://docs.python.org/3/library/logging.html#levels>`_
    for information about the differences between levels.


Configuration options resolution order:
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1. no_configure_logging option
2. ``--background``
3. ``--logdir``
4. ``--logging-conf-file``
5. logging_conf_file option
6. ``logging`` section
7. ``--log-level``
8. log_level option


================================================
FILE: doc/luigi_patterns.rst
================================================
Luigi Patterns
--------------

Code Reuse
~~~~~~~~~~

One nice thing about Luigi is that it's super easy to depend on tasks defined in other repos.
It's also trivial to have "forks" in the execution path,
where the output of one task may become the input of many other tasks.

Currently, no semantics for "intermediate" output are supported,
meaning that all output will be persisted indefinitely.
The upside of that is that if you try to run X -> Y, and Y crashes,
you can resume with the previously built X.
The downside is that you will have a lot of intermediate results on your file system.
A useful pattern is to put these files in a special directory and
have some kind of periodical garbage collection clean it up.

Triggering Many Tasks
~~~~~~~~~~~~~~~~~~~~~

A convenient pattern is to have a dummy Task at the end of several
dependency chains, so you can trigger a multitude of pipelines by
specifying just one task on the command line, similarly to how e.g. `make <https://www.gnu.org/software/make/>`_
works.

.. code:: python

    class AllReports(luigi.WrapperTask):
        date = luigi.DateParameter(default=datetime.date.today())
        def requires(self):
            yield SomeReport(self.date)
            yield SomeOtherReport(self.date)
            yield CropReport(self.date)
            yield TPSReport(self.date)
            yield FooBarBazReport(self.date)

This simple task will not do anything itself, but will invoke a bunch of
other tasks. On each invocation, Luigi will perform as many of the pending
jobs as possible (those which have all their dependencies present).

You'll need to use :class:`~luigi.task.WrapperTask` for this instead of the usual Task class, because this job will not produce any output of its own, and as such needs a way to indicate when it's complete. This class is used for tasks that only wrap other tasks and that by definition are done if all their requirements exist.

Triggering recurring tasks
~~~~~~~~~~~~~~~~~~~~~~~~~~

A common requirement is to have a daily report (or something else)
produced every night. Sometimes for various reasons tasks will keep
crashing or lacking their required dependencies for more than a day
though, which would lead to a missing deliverable for some date. Oops.

To ensure that the above AllReports task is eventually completed for
every day (value of date parameter), one could e.g. add a loop in
requires method to yield dependencies on the past few days preceding
self.date. Then, so long as Luigi keeps being invoked, the backlog of
jobs would catch up nicely after fixing intermittent problems.

Luigi actually comes with a reusable tool for achieving this, called
:class:`~luigi.tools.range.RangeDailyBase` (resp. :class:`~luigi.tools.range.RangeHourlyBase`). Simply putting

.. code-block:: console

	luigi --module all_reports RangeDailyBase --of AllReports --start 2015-01-01

in your crontab will easily keep gaps from occurring from 2015-01-01
onwards. NB - it will not always loop over everything from 2015-01-01
till current time though, but rather a maximum of 3 months ago by
default - see :class:`~luigi.tools.range.RangeDailyBase` documentation for this and more knobs
for tweaking behavior. See also Monitoring below.

Efficiently triggering recurring tasks
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

RangeDailyBase, described above, is named like that because a more
efficient subclass exists, :class:`~luigi.tools.range.RangeDaily` (resp. :class:`~luigi.tools.range.RangeHourly`), tailored for
hundreds of task classes scheduled concurrently with contiguousness
requirements spanning years (which would incur redundant completeness
checks and scheduler overload using the naive looping approach.) Usage:

.. code-block:: console

	luigi --module all_reports RangeDaily --of AllReports --start 2015-01-01

It has the same knobs as RangeDailyBase, with some added requirements.
Namely the task must implement an efficient bulk_complete method, or
must be writing output to file system Target with date parameter value
consistently represented in the file path.

Backfilling tasks
~~~~~~~~~~~~~~~~~

Also a common use case, sometimes you have tweaked existing recurring
task code and you want to schedule recomputation of it over an interval
of dates for that or another reason. Most conveniently it is achieved
with the above described range tools, just with both start (inclusive)
and stop (exclusive) parameters specified:

.. code-block:: console

	luigi --module all_reports RangeDaily --of AllReportsV2 --start 2014-10-31 --stop 2014-12-25

Propagating parameters with Range
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Some tasks you want to recur may include additional parameters which need to be configured.
The Range classes provide a parameter which accepts a :class:`~luigi.parameter.DictParameter`
and passes any parameters onwards for this purpose.

.. code-block:: console

	luigi RangeDaily --of MyTask --start 2014-10-31 --of-params '{"my_string_param": "123", "my_int_param": 123}'

Alternatively, you can specify parameters at the task family level (as described :ref:`here <Parameter-class-level-parameters>`),
however these will not appear in the task name for the upstream Range task which
can have implications in how the scheduler and visualizer handle task instances.

.. code-block:: console

	luigi RangeDaily --of MyTask --start 2014-10-31 --MyTask-my-param 123

.. _batch_method:

Batching multiple parameter values into a single run
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Sometimes it'll be faster to run multiple jobs together as a single
batch rather than running them each individually. When this is the case,
you can mark some parameters with a batch_method in their constructor
to tell the worker how to combine multiple values. One common way to do
this is by simply running the maximum value. This is good for tasks that
overwrite older data when a newer one runs. You accomplish this by
setting the batch_method to max, like so:

.. code-block:: python

    class A(luigi.Task):
        date = luigi.DateParameter(batch_method=max)

What's exciting about this is that if you send multiple As to the
scheduler, it can combine them and return one. So if
``A(date=2016-07-28)``, ``A(date=2016-07-29)`` and
``A(date=2016-07-30)`` are all ready to run, you will start running
``A(date=2016-07-30)``. While this is running, the scheduler will show
``A(date=2016-07-28)``, ``A(date=2016-07-29)`` as batch running while
``A(date=2016-07-30)`` is running. When ``A(date=2016-07-30)`` is done
running and becomes FAILED or DONE, the other two tasks will be updated
to the same status.

If you want to limit how big a batch can get, simply set max_batch_size.
So if you have

.. code-block:: python

    class A(luigi.Task):
        date = luigi.DateParameter(batch_method=max)

        max_batch_size = 10

then the scheduler will batch at most 10 jobs together. You probably do
not want to do this with the max batch method, but it can be helpful if
you use other methods. You can use any method that takes a list of
parameter values and returns a single parameter value.

If you have two max batch parameters, you'll get the max values for both
of them. If you have parameters that don't have a batch method, they'll
be aggregated separately. So if you have a class like

.. code-block:: python

    class A(luigi.Task):
        p1 = luigi.IntParameter(batch_method=max)
        p2 = luigi.IntParameter(batch_method=max)
        p3 = luigi.IntParameter()

and you create tasks ``A(p1=1, p2=2, p3=0)``, ``A(p1=2, p2=3, p3=0)``,
``A(p1=3, p2=4, p3=1)``, you'll get them batched as
``A(p1=2, p2=3, p3=0)`` and ``A(p1=3, p2=4, p3=1)``.

Note that batched tasks do not take up :ref:`resources-config`, only the
task that ends up running will use resources. The scheduler only checks
that there are sufficient resources for each task individually before
batching them all together.

Tasks that regularly overwrite the same data source
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you are overwriting the same data source with every run, you'll
need to ensure that two batches can't run at the same time. You can do
this pretty easily by setting batch_method to max and setting a unique
resource:

.. code-block:: python

    class A(luigi.Task):
        date = luigi.DateParameter(batch_method=max)

        resources = {'overwrite_resource': 1}

Now if you have multiple tasks such as ``A(date=2016-06-01)``,
``A(date=2016-06-02)``, ``A(date=2016-06-03)``, the scheduler will just
tell you to run the highest available one and mark the lower ones as
batch_running. Using a unique resource will prevent multiple tasks from
writing to the same location at the same time if a new one becomes
available while others are running.

Avoiding concurrent writes to a single file
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Updating a single file from several tasks is almost always a bad idea, and you
need to be very confident that no other good solution exists before doing this.
If, however, you have no other option, then you will probably at least need to ensure that
no two tasks try to write to the file *simultaneously*.

By turning 'resources' into a Python property, it can return a value dependent on
the task parameters or other dynamic attributes:

.. code-block:: python

    class A(luigi.Task):
        ...

        @property
        def resources(self):
            return { self.important_file_name: 1 }

Since, by default, resources have a usage limit of 1, no two instances of Task A
will now run if they have the same ``important_file_name`` property.

Decreasing resources of running tasks
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

At scheduling time, the luigi scheduler needs to be aware of the maximum
resource consumption a task might have once it runs. For some tasks, however,
it can be beneficial to decrease the amount of consumed resources between two
steps within their run method (e.g. after some heavy computation). In this
case, a different task waiting for that particular resource can already be
scheduled.

.. code-block:: python

    class A(luigi.Task):

        # set maximum resources a priori
        resources = {"some_resource": 3}

        def run(self):
            # do something
            ...

            # decrease consumption of "some_resource" by one
            self.decrease_running_resources({"some_resource": 1})

            # continue with reduced resources
            ...

Monitoring task pipelines
~~~~~~~~~~~~~~~~~~~~~~~~~

Luigi comes with some existing ways in :py:mod:`luigi.notifications` to receive
notifications whenever tasks crash. Email is the most common way.

The above mentioned range tools for recurring tasks not only implement
reliable scheduling for you, but also emit events which you can use to
set up delay monitoring. That way you can implement alerts for when
jobs are stuck for prolonged periods lacking input data or otherwise
requiring attention.

.. _AtomicWrites:

Atomic Writes Problem
~~~~~~~~~~~~~~~~~~~~~

A very common mistake made by luigi plumbers is to write data partially to the
final destination, that is, not atomically. The problem arises because
completion checks in luigi are exactly as naive as running
:meth:`luigi.target.Target.exists`. And in many cases it just means to check if
a folder exists on disk. While the data is only partially written, a task
depending on that output would think its input is complete. This can have
devastating effects, as in `the thanksgiving bug
`__.

The concept can be illustrated by imagining that we deal with data stored on
local disk and by running commands:

.. code-block:: console

    # This is the BAD way
    $ mkdir /outputs/final_output
    $ big-slow-calculation > /outputs/final_output/foo.data

As stated earlier, the problem is that only partial data exists for a duration,
yet we consider the data to be :meth:`~luigi.task.Task.complete` because the
output folder already exists. Here is a robust version of this:

.. code-block:: console

    # This is the good way
    $ mkdir /outputs/final_output-tmp-123456
    $ big-slow-calculation > /outputs/final_output-tmp-123456/foo.data
    $ mv --no-target-directory --no-clobber /outputs/final_output{-tmp-123456,}
    $ [[ -d /outputs/final_output-tmp-123456 ]] && rm -r /outputs/final_output-tmp-123456

Indeed, the good way is not as trivial. It involves coming up with a unique
directory name and a pretty complex ``mv`` line. ``mv`` needs all those flags
because we don't want it to move a directory into a potentially
existing directory. A directory could already exist in exceptional cases, for
example when central locking fails and the same task would somehow run twice at
the same time. Lastly, in the exceptional case where the file was never moved,
one might want to remove the temporary directory that never got used.

Note that this was an example where the storage was on local disk. But for
every storage (hard disk file, hdfs file, database table, etc.) this procedure
will look different. But does every luigi user need to implement that complexity?
Nope. Thankfully, luigi developers are aware of this, and luigi comes with many
built-in solutions. If you're dealing with a file system
(:class:`~luigi.target.FileSystemTarget`), you should consider using
:meth:`~luigi.target.FileSystemTarget.temporary_path`. For other targets, you
should ensure that the way you're writing your final output directory is
atomic.

Sending messages to tasks
~~~~~~~~~~~~~~~~~~~~~~~~~

The central scheduler is able to send messages to particular tasks. When a running task accepts
messages, it can access a `multiprocessing.Queue <https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue>`__
object storing incoming messages. You can implement custom behavior to react and respond to
messages:

.. code-block:: python

    class Example(luigi.Task):

        # common task setup
        ...

        # configure the task to accept all incoming messages
        accepts_messages = True

        def run(self):
            # this example runs some loop and listens for the
            # "terminate" message, and responds to all other messages
            for _ in some_loop():
                # check incoming messages
                if not self.scheduler_messages.empty():
                    msg = self.scheduler_messages.get()
                    if msg.content == "terminate":
                        break
                    else:
                        msg.respond("unknown message")

            # finalize
            ...

Messages can be sent right from the scheduler UI which also displays responses (if any). Note that
this feature is only available when the scheduler is configured to send messages (see the :ref:`scheduler-config` config), and the task is configured to accept them.

Gathering custom metrics from tasks' executions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The central scheduler is able to gather custom metrics from tasks' executions with help of
custom metrics collector (see the :ref:`scheduler-config` config). To obtain custom metrics,
you need to implement:

#. Custom metrics collector class inheriting from
   :class:`~luigi.metrics.MetricsCollector` (or derived) and implementing the
   :meth:`~luigi.metrics.MetricsCollector.handle_task_statistics`
   method (the default implementation does nothing). This method will be called
   for each task that has been executed, every time
   :meth:`~luigi.worker.TaskStatusReporter.report_task_statistics` is called.
   For instance, the following metrics collector adds monitoring of tasks'
   execution time and memory usage:

   .. code-block:: python

       class MetricsCollector(PrometheusMetricsCollector):
           def __init__(self, *args, **kwargs):
               super().__init__(*args, **kwargs)
               self.task_run_execution_time = Gauge(
                   'luigi_task_run_execution_time_seconds',
                   'luigi task run method execution time in seconds',
                   self.labels,
                   registry=self.registry
               )
               self.task_execution_memory = Gauge(
                   'luigi_task_max_memory_megabytes',
                   'luigi task run method max memory usage in megabytes',
                   self.labels,
                   registry=self.registry
               )

           def handle_task_statistics(self, task, statistics):
               if "elapsed" in statistics:
                   self.task_run_execution_time.labels(**self._generate_task_labels(task)).set(statistics["elapsed"])
               if "memory" in statistics:
                   self.task_execution_memory.labels(**self._generate_task_labels(task)).set(statistics["memory"])

#. Custom task context manager (see the :ref:`worker-config` config),
   whose ``__exit__`` method calls
   :meth:`~luigi.worker.TaskStatusReporter.report_task_statistics` with
   the statistics dictionary. For instance, the following task context manager collects
   task execution time and memory usage:

   .. code-block:: python

       class TaskContext:
           def __init__(self, task_process):
               self._task_process = task_process
               self._start = None

           def __enter__(self):
               self._start = time.perf_counter()
               return self

           def __exit__(self, exc_type, exc_val, exc_tb):
               assert self._start is not None
               elapsed = time.perf_counter() - self._start
               used_memory = max(
                   resource.getrusage(resource.RUSAGE_SELF).ru_maxrss, resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss
               )
               logging.getLogger("luigi-interface").info(
                   f'Task {self._task_process.task}: time: {elapsed:.2f}s, memory: {used_memory / 1024:.2f}MB '
               )
               self._task_process.status_reporter.report_task_statistics({"memory": used_memory / 1024, "elapsed": elapsed})


================================================
FILE: doc/mypy.rst
================================================
Mypy plugin
--------------

Mypy plugin provides type checking for ``luigi.Task`` using Mypy.

Requires Python 3.8 or later.

How to use
~~~~~~~~~~

Configure Mypy to use this plugin by adding the following to your ``mypy.ini`` file:

.. code:: ini

    [mypy]
    plugins = luigi.mypy

or by adding the following to your ``pyproject.toml`` file:

.. code:: toml

    [tool.mypy]
    plugins = ["luigi.mypy"]

Then, run Mypy as usual.

Examples
~~~~~~~~

For example, consider the following code linted by Mypy:

.. code:: python

    import luigi


    class MyTask(luigi.Task):
        foo: int = luigi.IntParameter()
        bar: str = luigi.Parameter()

    MyTask(foo=1, bar='2')   # OK
    MyTask(foo='1', bar='2')  # Error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int"


================================================
FILE: doc/parameters.rst
================================================
Parameters
----------

Parameters are the Luigi equivalent of creating a constructor for each Task.
Luigi requires you to declare these parameters by instantiating
:class:`~luigi.parameter.Parameter` objects on the class scope:

.. code:: python

    class DailyReport(luigi.contrib.hadoop.JobTask):
        date = luigi.DateParameter(default=datetime.date.today())
        # ...

By doing this, Luigi can take care of all the boilerplate code that
would normally be needed in the constructor.
Internally, the DailyReport object can now be constructed by running
``DailyReport(datetime.date(2012, 5, 10))`` or just ``DailyReport()``.
Luigi also creates a command line parser that automatically handles the
conversion from strings to Python types.
This way you can invoke the job on the command line, e.g. by passing ``--date 2012-05-10``.

The parameters are all set to their values on the Task object instance,
i.e.

.. code:: python

    d = DailyReport(datetime.date(2012, 5, 10))
    print(d.date)

will print the same date that the object was constructed with.
Same goes if you invoke Luigi on the command line.

.. _Parameter-instance-caching:

Instance caching
^^^^^^^^^^^^^^^^

Tasks are uniquely identified by their class name and values of their
parameters.
In fact, within the same worker, two tasks of the same class with
parameters of the same values are not just equal, but the same instance:

.. code:: python

    >>> import luigi
    >>> import datetime
    >>> class DateTask(luigi.Task):
    ...   date = luigi.DateParameter()
    ...
    >>> a = datetime.date(2014, 1, 21)
    >>> b = datetime.date(2014, 1, 21)
    >>> a is b
    False
    >>> c = DateTask(date=a)
    >>> d = DateTask(date=b)
    >>> c
    DateTask(date=2014-01-21)
    >>> d
    DateTask(date=2014-01-21)
    >>> c is d
    True

Insignificant parameters
^^^^^^^^^^^^^^^^^^^^^^^^

If a parameter is created with ``significant=False``,
it is ignored as far as the Task signature is concerned.
Tasks created with only insignificant parameters differing have the same signature but
are not the same instance:

.. code:: python

    >>> class DateTask2(DateTask):
    ...   other = luigi.Parameter(significant=False)
    ...
    >>> c = DateTask2(date=a, other="foo")
    >>> d = DateTask2(date=b, other="bar")
    >>> c
    DateTask2(date=2014-01-21)
    >>> d
    DateTask2(date=2014-01-21)
    >>> c.other
    'foo'
    >>> d.other
    'bar'
    >>> c is d
    False
    >>> hash(c) == hash(d)
    True

Parameter visibility
^^^^^^^^^^^^^^^^^^^^

Using :class:`~luigi.parameter.ParameterVisibility` you can configure parameter visibility. By default, all
parameters are public, but you can also set them hidden or private.

.. code:: python

    >>> import luigi
    >>> from luigi.parameter import ParameterVisibility
    
    >>> luigi.Parameter(visibility=ParameterVisibility.PRIVATE)

``ParameterVisibility.PUBLIC`` (default) - visible everywhere

``ParameterVisibility.HIDDEN`` - ignored in the web view, but saved into the database if db_history saving is enabled

``ParameterVisibility.PRIVATE`` - visible only inside task.

Parameter types
^^^^^^^^^^^^^^^

In the examples above, the *type* of the parameter is determined by using different
subclasses of :class:`~luigi.parameter.Parameter`. There are a few of them, like
:class:`~luigi.parameter.DateParameter`,
:class:`~luigi.parameter.DateIntervalParameter`,
:class:`~luigi.parameter.IntParameter`,
:class:`~luigi.parameter.FloatParameter`, etc.

Python is not a statically typed language and you don't have to specify the types
of any of your parameters.
You can simply use the base class :class:`~luigi.parameter.Parameter` if you don't care.

The reason you would use a subclass like :class:`~luigi.parameter.DateParameter`
is that Luigi needs to know its type for the command line interaction.
That's how it knows how to convert a string provided on the command line to
the corresponding type (i.e. datetime.date instead of a string).

.. _Parameter-class-level-parameters:

Setting parameter value for other classes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

All parameters are also exposed on a class level on the command line interface.
For instance, say you have classes TaskA and TaskB:

.. code:: python

    class TaskA(luigi.Task):
        x = luigi.Parameter()

    class TaskB(luigi.Task):
        y = luigi.Parameter()


You can run ``TaskB`` on the command line: ``luigi TaskB --y 42``.
But you can also set the class value of ``TaskA`` by running
``luigi TaskB --y 42 --TaskA-x 43``.
This sets the value of ``TaskA.x`` to 43 on a *class* level.
It is still possible to override it inside Python if you instantiate ``TaskA(x=44)``.

All parameters can also be set from the configuration file.
For instance, you can put this in the config:

.. code:: ini

    [TaskA]
    x: 45


Just as in the previous case, this will set the value of ``TaskA.x`` to 45 on the *class* level.
And likewise, it is still possible to override it inside Python if you instantiate ``TaskA(x=44)``.

Parameter resolution order
^^^^^^^^^^^^^^^^^^^^^^^^^^

Parameters are resolved in the following order of decreasing priority:

1. Any value passed to the constructor, or task level value set on the command line (applies on an instance level)
2. Any value set on the command line (applies on a class level)
3. Any configuration option (applies on a class level)
4. Any default value provided to the parameter (applies on a class level)

See the :class:`~luigi.parameter.Parameter` class for more information.


================================================
FILE: doc/running_luigi.rst
================================================
Running Luigi
-------------

Running from the Command Line
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The preferred way to run Luigi tasks is through the ``luigi`` command line tool
that will be installed with the pip package.

.. code-block:: python

    # my_module.py, available in your sys.path
    import luigi

    class MyTask(luigi.Task):
        x = luigi.IntParameter()
        y = luigi.IntParameter(default=45)

        def run(self):
            print(self.x + self.y)

Should be run like this

.. code-block:: console

        $ luigi --module my_module MyTask --x 123 --y 456 --local-scheduler

Or alternatively like this:

.. code-block:: console

        $ python -m luigi --module my_module MyTask --x 100 --local-scheduler

Note that if a parameter name contains '_', it should be replaced by '-'.
For example, if MyTask had a parameter called 'my_parameter':

.. code-block:: console

        $ luigi --module my_module MyTask --my-parameter 100 --local-scheduler

.. note:: Please make sure to always place task parameters behind the task family!


Running from Python code
^^^^^^^^^^^^^^^^^^^^^^^^

Another way to start tasks from Python code is to use ``luigi.build(tasks, worker_scheduler_factory=None, **env_params)``
from the ``luigi.interface`` module.

This way of running luigi tasks is useful if you want to get some dynamic parameters from another
source, such as database, or provide additional logic before you start tasks.

One notable difference is that ``build`` defaults to not using the identical process lock.
If you want to change this behaviour, just pass ``no_lock=False``.


.. code-block:: python

    class MyTask1(luigi.Task):
        x = luigi.IntParameter()
        y = luigi.IntParameter(default=0)

        def run(self):
            print(self.x + self.y)


    class MyTask2(luigi.Task):
        x = luigi.IntParameter()
        y = luigi.IntParameter(default=1)
        z = luigi.IntParameter(default=2)

        def run(self):
            print(self.x * self.y * self.z)


    if __name__ == '__main__':
        luigi.build([MyTask1(x=10), MyTask2(x=15, z=3)])


Also, it is possible to pass additional parameters to ``build`` such as host, port, workers and local_scheduler:

.. code-block:: python

    if __name__ == '__main__':
         luigi.build([MyTask1(x=1)], workers=5, local_scheduler=True)

To meet special requirements, you can pass your own ``worker_scheduler_factory`` to ``build``;
it should return your worker and/or scheduler implementations:

.. code-block:: python

    class MyWorker(Worker):
        # some custom logic


    class MyFactory:
      def create_local_scheduler(self):
          return scheduler.Scheduler(prune_on_get_work=True, record_task_history=False)

      def create_remote_scheduler(self, url):
          return rpc.RemoteScheduler(url)

      def create_worker(self, scheduler, worker_processes, assistant=False):
          # return your worker instance
          return MyWorker(
              scheduler=scheduler, worker_processes=worker_processes, assistant=assistant)


    if __name__ == '__main__':
        luigi.build([MyTask1(x=1)], worker_scheduler_factory=MyFactory())

This can be useful in some cases, such as implementing a task queue.



Response of luigi.build()/luigi.run()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- **Default response** By default *luigi.build()/luigi.run()* returns True if there were no scheduling errors. This is the same as the attribute ``LuigiRunResult.scheduling_succeeded``.

- **Detailed response** This is a response of type :class:`~luigi.execution_summary.LuigiRunResult`. This is obtained by passing a keyword argument ``detailed_summary=True`` to *build/run*. This response contains detailed information about the jobs.

  .. code-block:: python

    if __name__ == '__main__':
         luigi_run_result = luigi.build(..., detailed_summary=True)
         print(luigi_run_result.summary_text)


Luigi on Windows
^^^^^^^^^^^^^^^^

Most Luigi functionality works on Windows. Exceptions:

- Specifying multiple worker processes using the ``workers`` argument for
  ``luigi.build``, or using the ``--workers`` command line argument. (Similarly,
  specifying ``--worker-force-multiprocessing``). For most programs, this will
  result in failure (a common sight is ``BrokenPipeError``). The reason is that
  worker processes are assumed to be forked from the main process. Forking is
  `not possible <https://docs.python.org/3/library/os.html#os.fork>`_
  on Windows.
- Running the Luigi central scheduling server as a daemon (i.e. with ``--background``).
  Again, a Unix-only concept.


================================================
FILE: doc/tasks.rst
================================================
Tasks
-----

Tasks are where the execution takes place.
Tasks depend on each other and output targets.

An outline of what a task can look like:

    .. figure:: task_breakdown.png
       :alt: Task breakdown

.. _Task.requires:

Task.requires
~~~~~~~~~~~~~

The :func:`~luigi.task.Task.requires` method is used to specify dependencies on other Task objects,
which might even be of the same class.
For instance, an example implementation could be

.. code:: python

    def requires(self):
        return OtherTask(self.date), DailyReport(self.date - datetime.timedelta(1))

In this case, the DailyReport task depends on two inputs created earlier,
one of which is the same class.
requires can return other Tasks in any way wrapped up within dicts/lists/tuples/etc.

Requiring another Task
~~~~~~~~~~~~~~~~~~~~~~

Note that :func:`~luigi.task.Task.requires` can *not* return a :class:`~luigi.target.Target` object.
If you have a simple Target object that is created externally
you can wrap it in a Task class like this:

.. code:: python

    class LogFiles(luigi.ExternalTask):
        def output(self):
            return luigi.contrib.hdfs.HdfsTarget('/log')

This also makes it easier to add parameters:

.. code:: python

    class LogFiles(luigi.ExternalTask):
        date = luigi.DateParameter()
        def output(self):
            return luigi.contrib.hdfs.HdfsTarget(self.date.strftime('/log/%Y-%m-%d'))

.. _Task.output:

Task.output
~~~~~~~~~~~

The :func:`~luigi.task.Task.output` method returns one or more :class:`~luigi.target.Target` objects.
Similarly to requires, you can return them wrapped up in any way that's convenient for you.
However we recommend that any :class:`~luigi.task.Task` only return one single :class:`~luigi.target.Target` in output.
If multiple outputs are returned,
atomicity will be lost unless the :class:`~luigi.task.Task` itself can ensure that each :class:`~luigi.target.Target` is atomically created.
(If atomicity is not of concern, then it is safe to return multiple :class:`~luigi.target.Target` objects.)

.. code:: python

    class DailyReport(luigi.Task):
        date = luigi.DateParameter()
        def output(self):
            return luigi.contrib.hdfs.HdfsTarget(self.date.strftime('/reports/%Y-%m-%d'))
        # ...

.. _Task.run:

Task.run
~~~~~~~~

The :func:`~luigi.task.Task.run` method now contains the actual code that is run.
When you are using Task.requires_ and Task.run_, Luigi breaks down everything into two stages.
First it figures out all dependencies between tasks,
then it runs everything.
The :func:`~luigi.task.Task.input` method is an internal helper method that just replaces all Task objects in requires
with their corresponding output.
An example:

.. code:: python

    class GenerateWords(luigi.Task):

        def output(self):
            return luigi.LocalTarget('words.txt')

        def run(self):

            # write a dummy list of words to output file
            words = [
                    'apple',
                    'banana',
                    'grapefruit'
                    ]

            with self.output().open('w') as f:
                for word in words:
                    f.write('{word}\n'.format(word=word))


    class CountLetters(luigi.Task):

        def requires(self):
            return GenerateWords()

        def output(self):
            return luigi.LocalTarget('letter_counts.txt')

        def run(self):

            # read in file as list
            with self.input().open('r') as infile:
                words = infile.read().splitlines()

            # write each word to output file with its corresponding letter count
            with self.output().open('w') as outfile:
                for word in words:
                    outfile.write(
                            '{word} | {letter_count}\n'.format(
                                word=word,
                                letter_count=len(word)
                                )
                            )

It's useful to note that if you're writing to a binary file, Luigi automatically
strips the ``'b'`` flag due to how atomic writes/reads work. In order to write a binary
file, such as a pickle file, you should instead use ``format=Nop`` when calling
LocalTarget. Following the above example:

.. code:: python

    from luigi.format import Nop

    class GenerateWords(luigi.Task):

        def output(self):
            return luigi.LocalTarget('words.pckl', format=Nop)

        def run(self):
            import pickle

            # write a dummy list of words to output file
            words = [
                    'apple',
                    'banana',
                    'grapefruit'
                    ]

            with self.output().open('w') as f:
                pickle.dump(words, f)


It is your responsibility to ensure that after running :func:`~luigi.task.Task.run`, the task is
complete, i.e. :func:`~luigi.task.Task.complete` returns ``True``. Unless you have overridden
:func:`~luigi.task.Task.complete`, :func:`~luigi.task.Task.run` should generate all the targets
defined as outputs. Luigi verifies that you adhere to the contract before running downstream
dependencies, and reports ``Unfulfilled dependencies at run time`` if a violation is detected.

.. _Task.input:

Task.input
~~~~~~~~~~

As seen in the example above, :func:`~luigi.task.Task.input` is a wrapper around Task.requires_ that
returns the corresponding Target objects instead of Task objects.
Anything returned by Task.requires_ will be transformed, including lists,
nested dicts, etc.
This can be useful if you have many dependencies:

.. code:: python

    class TaskWithManyInputs(luigi.Task):
        def requires(self):
            return {'a': TaskA(), 'b': [TaskB(i) for i in range(100)]}

        def run(self):
            f = self.input()['a'].open('r')
            g = [y.open('r') for y in self.input()['b']]


Dynamic dependencies
~~~~~~~~~~~~~~~~~~~~

Sometimes you might not know exactly what other tasks to depend on until runtime.
In that case, Luigi provides a mechanism to specify dynamic dependencies.
If you yield another :class:`~luigi.task.Task` in the Task.run_ method,
the current task will be suspended and the other task will be run.
You can also yield a list of tasks.

.. code:: python

    class MyTask(luigi.Task):
        def run(self):
            other_target = yield OtherTask()

            # dynamic dependencies resolve into targets
            f = other_target.open('r')


This mechanism is an alternative to Task.requires_ in case
you are not able to build up the full dependency graph before running the task.
It does come with some constraints:
the Task.run_ method will resume from scratch each time a new task is yielded.
In other words, you should make sure your Task.run_ method is idempotent.
(This is good practice for all Tasks in Luigi, but especially so for tasks with dynamic dependencies).
As this might entail redundant calls to tasks' :func:`~luigi.task.Task.complete` methods,
you should consider setting the "cache_task_completion" option in the :ref:`worker-config`.
To further control how dynamic task requirements are handled internally by worker nodes,
there is also the option to wrap dependent tasks by :class:`~luigi.task.DynamicRequirements`.

For an example of a workflow using dynamic dependencies, see
`examples/dynamic_requirements.py <https://github.com/spotify/luigi/blob/master/examples/dynamic_requirements.py>`_.


Task status tracking
~~~~~~~~~~~~~~~~~~~~

For long-running or remote tasks it is convenient to see extended status information not only on
the command line or in your logs but also in the GUI of the central scheduler. Luigi implements
dynamic status messages, progress bar and tracking urls which may point to an external monitoring system.
You can set this information using callbacks within Task.run_:

.. code:: python

    class MyTask(luigi.Task):
        def run(self):
            # set a tracking url
            self.set_tracking_url("http://...")

            # set status messages during the workload
            for i in range(100):
                # do some hard work here
                if i % 10 == 0:
                    self.set_status_message("Progress: %d / 100" % i)
                    # displays a progress bar in the scheduler UI
                    self.set_progress_percentage(i)


.. _Events:

Events and callbacks
~~~~~~~~~~~~~~~~~~~~

Luigi has a built-in event system that
allows you to register callbacks to events and trigger them from your own tasks.
You can both hook into some pre-defined events and create your own.
Each event handle is tied to a Task class and
will be triggered only from that class or
a subclass of it.
This allows you to effortlessly subscribe to events only from a specific class (e.g. for hadoop jobs).

.. code:: python

    @luigi.Task.event_handler(luigi.Event.SUCCESS)
    def celebrate_success(task):
        """Will be called directly after a successful execution
           of `run` on any Task subclass (i.e. all luigi Tasks)
        """
        ...

    @luigi.contrib.hadoop.JobTask.event_handler(luigi.Event.FAILURE)
    def mourn_failure(task, exception):
        """Will be called directly after a failed execution
           of `run` on any JobTask subclass
        """
        ...

    luigi.run()


But I just want to run a Hadoop job?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The Hadoop code is integrated in the rest of the Luigi code because
we really believe almost all Hadoop jobs benefit from being part of some sort of workflow.
However, in theory, nothing stops you from using the :class:`~luigi.contrib.hadoop.JobTask` class (and also :class:`~luigi.contrib.hdfs.target.HdfsTarget`)
without using the rest of Luigi.
You can simply run it manually using

.. code:: python

    MyJobTask('abc', 123).run()

You can use the hdfs.target.HdfsTarget class anywhere by just instantiating it:

.. code:: python

    t = luigi.contrib.hdfs.target.HdfsTarget('/tmp/test.gz', format=format.Gzip)
    f = t.open('w')
    # ...
    f.close() # needed

.. _Task.priority:

Task priority
~~~~~~~~~~~~~

The scheduler decides which task to run next from
the set of all tasks that have all their dependencies met.
By default, this choice is pretty arbitrary,
which is fine for most workflows and situations.

If you want to have some control over the order of execution of available tasks,
you can set the ``priority`` property of a task,
for example as follows:

.. code:: python

    # A static priority value as a class constant:
    class MyTask(luigi.Task):
        priority = 100
        # ...

    # A dynamic priority value with a "@property" decorated method:
    class OtherTask(luigi.Task):
        @property
        def priority(self):
            if self.date > some_threshold:
                return 80
            else:
                return 40
        # ...

Tasks with a higher priority value will be picked before tasks with a lower priority value.
There is no predefined range of priorities;
you can choose whatever (int or float) values you want to use.
The default value is 0.

Warning: task execution order in Luigi is influenced by both dependencies and priorities, but
in Luigi dependencies come first.
For example:
if there is a task A with priority 1000 but still with unmet dependencies and
a task B with priority 1 without any pending dependencies,
task B will be picked first.

.. _Task.namespaces_famlies_and_ids:

Namespaces, families and ids
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In order to avoid name clashes and to be able to have an identifier for tasks,
Luigi introduces the concepts *task_namespace*, *task_family* and
*task_id*. The namespace and family operate at the class level, whereas the task
id only exists at the instance level. The concepts are best illustrated using code.

.. code:: python

    import luigi
    class MyTask(luigi.Task):
        my_param = luigi.Parameter()
        task_namespace = 'my_namespace'

    my_task = MyTask(my_param='hello')
    print(my_task)                      # --> my_namespace.MyTask(my_param=hello)

    print(my_task.get_task_namespace()) # --> my_namespace
    print(my_task.get_task_family())    # --> my_namespace.MyTask
    print(my_task.task_id)              # --> my_namespace.MyTask_hello_890907e7ce

    print(MyTask.get_task_namespace())  # --> my_namespace
    print(MyTask.get_task_family())     # --> my_namespace.MyTask
    print(MyTask.task_id)               # --> Error!

The full documentation for this machinery exists in the :py:mod:`~luigi.task` module.

Instance caching
~~~~~~~~~~~~~~~~

In addition to the features mentioned above,
Luigi also does some metaclass logic so that
if e.g. ``DailyReport(datetime.date(2012, 5, 10))`` is instantiated twice in the code,
it will in fact result in the same object.
See :ref:`Parameter-instance-caching` for more info.


================================================
FILE: doc/workflows.rst
================================================
Building workflows
------------------

There are two fundamental building blocks of Luigi -
the :class:`~luigi.task.Task` class and the :class:`~luigi.target.Target` class.
Both are abstract classes and expect a few methods to be implemented.
In addition to those two concepts,
the :class:`~luigi.parameter.Parameter` class is an important concept that governs how a Task is run.

Target
~~~~~~

The :py:class:`~luigi.target.Target` class corresponds to a file on a disk,
a file on HDFS or some kind of a checkpoint, like an entry in a database.
Actually, the only method that Targets have to implement is the *exists*
method which returns True if and only if the Target exists.

In practice, implementing Target subclasses is rarely needed.
Luigi comes with a toolbox of several useful Targets.
In particular, :class:`~luigi.local_target.LocalTarget` and :class:`~luigi.contrib.hdfs.target.HdfsTarget`,
but there is also support for other file systems:
:class:`luigi.contrib.s3.S3Target`,
:class:`luigi.contrib.ssh.RemoteTarget`,
:class:`luigi.contrib.ftp.RemoteTarget`,
:class:`luigi.contrib.mysqldb.MySqlTarget`,
:class:`luigi.contrib.redshift.RedshiftTarget`, and several more.

Most of these targets are file system-like.
For instance, :class:`~luigi.local_target.LocalTarget` and :class:`~luigi.contrib.hdfs.target.HdfsTarget` map to a file on the local drive or a file in HDFS.
In addition, these also wrap the underlying operations to make them atomic.
They both implement the :func:`~luigi.local_target.LocalTarget.open` method which returns a stream object that
can be read from (``mode='r'``) or written to (``mode='w'``).

Luigi comes with Gzip support by providing ``format=format.Gzip``.
Adding support for other formats is pretty simple.

Task
~~~~

The :class:`~luigi.task.Task` class is a bit more conceptually interesting because this is
where computation is done.
There are a few methods that can be implemented to alter its behavior,
most notably :func:`~luigi.task.Task.run`, :func:`~luigi.task.Task.output` and :func:`~luigi.task.Task.requires`.

Tasks consume Targets that were created by some other task. They usually also output targets:

    .. figure:: task_with_targets.png
       :alt: Task and targets

You can define dependencies between *Tasks* using the :py:meth:`~luigi.task.Task.requires` method. See :doc:`/tasks` for more info.

    .. figure:: tasks_with_dependencies.png
       :alt: Tasks and dependencies

Each task defines its outputs using the :py:meth:`~luigi.task.Task.output` method.
Additionally, there is a helper method :py:meth:`~luigi.task.Task.input` that returns the corresponding Target objects for each Task dependency.

    .. figure:: tasks_input_output_requires.png
       :alt: Tasks and methods

.. _Parameter:

Parameter
~~~~~~~~~

The Task class corresponds to some type of job that is run, but in
general you want to allow some form of parameterization of it.
For instance, if your Task class runs a Hadoop job to create a report every night,
you probably want to make the date a parameter of the class.
See :doc:`/parameters` for more info.

    .. figure:: task_parameters.png
       :alt: Tasks with parameters

Dependencies
~~~~~~~~~~~~

Using tasks, targets, and parameters, Luigi lets you express arbitrary dependencies in *code*, rather than using some kind of awkward config DSL.
This is really useful because in the real world, dependencies are often very messy.
For instance, some examples of the dependencies you might encounter:

    .. figure:: parameters_date_algebra.png
       :alt: Dependencies with date algebra

    .. figure:: parameters_recursion.png
       :alt: Dependencies with recursion

    .. figure:: parameters_enum.png
       :alt: Dependencies with enums

(These diagrams are from a `Luigi presentation in late 2014 at NYC Data Science meetup <http://www.slideshare.net/ErikBernhardsson/luigi-presentation-nyc-data-science>`_)


================================================
FILE: examples/__init__.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


================================================
FILE: examples/config.toml
================================================

[hdfs]
client = "hadoopcli"
namenode_host = "localhost"
namenode_port = 50030

# LOGGING

[logging]
version = 1
disable_existing_loggers = false

# logs format
[logging.formatters.simple]
format = "{levelname:8} {asctime} {module}:{lineno} {message}"
style = "{"
datefmt = "%Y-%m-%d %H:%M:%S"

# write logs to console
[logging.handlers.console]
level = "DEBUG"
class = "logging.StreamHandler"
formatter = "simple"

# luigi worker logging
[logging.loggers.luigi-interface]
handlers = ["console"]
level = "INFO"
disabled = false
propagate = false

# luigid logging
[logging.loggers.luigi]
handlers = ["console"]
level = "INFO"
disabled = false
propagate = false

# luigid is built on tornado
[logging.loggers.tornado]
handlers = ["console"]
level = "INFO"
disabled = false
propagate = false

# custom logger for "project"
[logging.loggers.project]
handlers = ["console"]
level = "DEBUG"
disabled = false
propagate = false


================================================
FILE: examples/dynamic_requirements.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import random as rnd
import time

import luigi


class Configuration(luigi.Task):
    """
    Writes a comma-separated list of pseudo-random integers derived from ``seed``.

    The file is read by :py:class:`Dynamic` to decide, at run time, which
    :py:class:`Data` tasks it needs.
    """

    seed = luigi.IntParameter()

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file on the local filesystem.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        return luigi.LocalTarget("/tmp/Config_%d.txt" % self.seed)

    def run(self):
        # Pretend this is an expensive computation.
        time.sleep(5)

        # Seeding first makes the generated list reproducible for a given seed.
        rnd.seed(self.seed)
        how_many = rnd.randint(7, 25)
        chosen = rnd.sample(list(range(300)), how_many)

        with self.output().open("w") as f:
            f.write(",".join(map(str, chosen)))


class Data(luigi.Task):
    """
    Writes its ``magic_number`` into a small file under ``/tmp``.

    Instances are created at run time by :py:class:`Dynamic`, one per number
    found in the :py:class:`Configuration` output.
    """

    magic_number = luigi.IntParameter()

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file on the local filesystem.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        path = "/tmp/Data_%d.txt" % self.magic_number
        return luigi.LocalTarget(path)

    def run(self):
        # Simulate a little work before producing the result.
        time.sleep(1)
        payload = "%s" % self.magic_number
        with self.output().open("w") as f:
            f.write(payload)


class Dynamic(luigi.Task):
    """
    Demonstrates dynamic dependencies, i.e. requirements only known at run time.

    ``run`` first yields a :py:class:`Configuration`, reads the numbers it wrote,
    then yields one :py:class:`Data` task per number. The task's own output is
    written only after every dynamic requirement has been yielded.
    """

    seed = luigi.IntParameter(default=1)

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file on the local filesystem.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        return luigi.LocalTarget("/tmp/Dynamic_%d.txt" % self.seed)

    def run(self):
        # This could be done using regular requires method
        config = self.clone(Configuration)
        yield config

        with config.output().open() as f:
            data = [int(x) for x in f.read().split(",")]

        # ... but not this
        data_dependent_deps = [Data(magic_number=x) for x in data]
        yield data_dependent_deps

        # and in case data is rather long, consider wrapping the requirements
        # in DynamicRequirements and optionally define a custom complete method
        def custom_complete(complete_fn):
            # example: Data() stores all outputs in the same directory, so avoid doing len(data) fs
            # calls but rather check only the first, and compare basenames for the rest
            # (complete_fn defaults to "lambda task: task.complete()" but can also include caching)
            if not data_dependent_deps:
                # Nothing to wait for (also avoids indexing an empty list below).
                return True
            if not complete_fn(data_dependent_deps[0]):
                return False
            paths = [task.output().path for task in data_dependent_deps]
            # A single fs call; a set gives O(1) membership tests for the check below.
            basenames = set(os.listdir(os.path.dirname(paths[0])))
            return all(os.path.basename(path) in basenames for path in paths)

        yield luigi.DynamicRequirements(data_dependent_deps, custom_complete)

        # Write the output last: a task whose output exists is considered
        # complete, so writing it before the final yield would mark this task
        # done while dynamic requirements could still be outstanding.
        with self.output().open("w") as f:
            f.write("Tada!")


if __name__ == "__main__":
    # Hand control to luigi's command-line entry point; the task to run and
    # its parameters presumably come from sys.argv.
    luigi.run()


================================================
FILE: examples/elasticsearch_index.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json

import luigi
from luigi.contrib.esindex import CopyToIndex


class FakeDocuments(luigi.Task):
    """
    Writes five small JSON documents, one per line, to a local file.

    The output path depends on ``date``; the ``date`` field stored inside each
    document is the day the task actually ran.
    """

    #: the date parameter.
    date = luigi.DateParameter(default=datetime.date.today())

    def run(self):
        """
        Writes data in JSON format into the task's output target.

        Every record carries:

        * `_id`, the default Elasticsearch id field,
        * `text`, a short greeting,
        * `date`, the day on which the record was generated.

        """
        created_on = str(datetime.date.today())
        records = ({"_id": n, "text": "Hi %s" % n, "date": created_on} for n in range(5))
        with self.output().open("w") as output:
            for record in records:
                # One JSON object per line (line-delimited JSON).
                output.write(json.dumps(record) + "\n")

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file on the local filesystem.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        return luigi.LocalTarget(path="/tmp/_docs-%s.ldj" % self.date)


class IndexDocuments(CopyToIndex):
    """
    This task loads JSON data contained in a :py:class:`luigi.target.Target` into an ElasticSearch index.

    This task's input will be the target returned by :py:meth:`~.FakeDocuments.output`
    for the same ``date``.

    This class uses :py:meth:`luigi.contrib.esindex.CopyToIndex.run`.

    After running this task you can run:

    .. code-block:: console

        $ curl "localhost:9200/example_index/_search?pretty"

    to see the indexed documents.

    To see the update log, run

    .. code-block:: console

        $ curl "localhost:9200/update_log/_search?q=target_index:example_index&pretty"

    To cleanup both indexes run:

    .. code-block:: console

        $ curl -XDELETE "localhost:9200/example_index"
        $ curl -XDELETE "localhost:9200/update_log/_query?q=target_index:example_index"

    """

    #: date task parameter (default = today)
    date = luigi.DateParameter(default=datetime.date.today())

    #: the name of the index in ElasticSearch to be updated.
    index = "example_index"
    #: the name of the document type.
    doc_type = "greetings"
    #: the host running the ElasticSearch service.
    host = "localhost"
    #: the port used by the ElasticSearch service.
    port = 9200

    def requires(self):
        """
        This task's dependencies:

        * :py:class:`~.FakeDocuments` for the same ``date``

        :return: object (:py:class:`luigi.task.Task`)
        """
        # Forward the date so that running with --date X indexes the documents
        # generated for X instead of silently falling back to today's file.
        return FakeDocuments(date=self.date)


if __name__ == "__main__":
    # Run IndexDocuments with the local scheduler (no central luigid required).
    luigi.run(["IndexDocuments", "--local-scheduler"])


================================================
FILE: examples/execution_summary_example.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2015-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
You can run this example like this:

    .. code:: console

            $ luigi --module examples.execution_summary_example examples.EntryPoint --local-scheduler
            ...
            ... lots of spammy output
            ...
            INFO: There are 11 pending tasks unique to this worker
            INFO: Worker Worker(salt=843361665, workers=1, host=arash-spotify-T440s, username=arash, pid=18534) was stopped. Shutting down Keep-Alive thread
            INFO:
            ===== Luigi Execution Summary =====

            Scheduled 218 tasks of which:
            * 195 complete ones were encountered:
                - 195 examples.Bar(num=5...199)
            * 1 ran successfully:
                - 1 examples.Boom(...)
            * 22 were left pending, among these:
                * 1 were missing external dependencies:
                    - 1 MyExternal()
                * 21 had missing dependencies:
                    - 1 examples.EntryPoint()
                    - examples.Foo(num=100, num2=16) and 9 other examples.Foo
                    - 10 examples.DateTask(date=1998-03-23...1998-04-01, num=5)

            This progress looks :| because there were missing external dependencies

            ===== Luigi Execution Summary =====
"""

import datetime

import luigi


class MyExternal(luigi.ExternalTask):
    """External dependency that never reports itself complete."""

    def complete(self):
        # Always incomplete: simulates external data that is missing, so the
        # tasks requiring it cannot run.
        return False


class Boom(luigi.Task):
    """Requires ``Bar(5)`` through ``Bar(199)``; note its very long parameter name."""

    task_namespace = "examples"
    this_is_a_really_long_I_mean_way_too_long_and_annoying_parameter = luigi.IntParameter()

    def run(self):
        print("Running Boom")

    def requires(self):
        # 195 dependencies: Bar(5) .. Bar(199) inclusive.
        yield from (Bar(num) for num in range(5, 200))


class Foo(luigi.Task):
    """Depends on the never-complete ``MyExternal`` as well as on ``Boom(0)``."""

    task_namespace = "examples"
    num = luigi.IntParameter()
    num2 = luigi.IntParameter()

    def run(self):
        print("Running Foo")

    def requires(self):
        yield from (MyExternal(), Boom(0))


class Bar(luigi.Task):
    """Creates an empty marker file at ``/tmp/bar/<num>``."""

    task_namespace = "examples"
    num = luigi.IntParameter()

    def run(self):
        # Opening for writing and closing right away leaves an empty file behind.
        handle = self.output().open("w")
        handle.close()

    def output(self):
        marker = "/tmp/bar/%d" % self.num
        return luigi.LocalTarget(marker)


class DateTask(luigi.Task):
    """Date-parameterised task with the same dependencies as ``Foo``."""

    task_namespace = "examples"
    date = luigi.DateParameter()
    num = luigi.IntParameter()

    def run(self):
        print("Running DateTask")

    def requires(self):
        yield from (MyExternal(), Boom(0))


class EntryPoint(luigi.Task):
    """Top-level task: ten ``Foo`` tasks plus ten consecutive daily ``DateTask`` runs."""

    task_namespace = "examples"

    def run(self):
        print("Running EntryPoint")

    def requires(self):
        first_day = datetime.date(1998, 3, 23)
        for idx in range(10):
            yield Foo(100, 2 * idx)
        for offset in range(10):
            yield DateTask(first_day + datetime.timedelta(days=offset), 5)


================================================
FILE: examples/foo.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
You can run this example like this:

    .. code:: console

            $ rm -rf '/tmp/bar'
            $ luigi --module examples.foo examples.Foo --workers 2 --local-scheduler

"""

import time

import luigi


class Foo(luigi.WrapperTask):
    """Wrapper task that fans out to ``Bar(0)`` .. ``Bar(9)``."""

    task_namespace = "examples"

    def run(self):
        print("Running Foo")

    def requires(self):
        yield from map(Bar, range(10))


class Bar(luigi.Task):
    """Sleeps for a second, then creates an empty marker file under ``/tmp/bar``."""

    task_namespace = "examples"
    num = luigi.IntParameter()

    def run(self):
        time.sleep(1)
        target = self.output()
        target.open("w").close()

    def output(self):
        """
        Returns the target output for this task.

        The one-second delay makes every call slow, presumably so the benefit of
        running with several workers is visible.

        :return: the target output for this task.
        :rtype: object (:py:class:`~luigi.target.Target`)
        """
        time.sleep(1)
        marker = "/tmp/bar/%d" % self.num
        return luigi.LocalTarget(marker)


================================================
FILE: examples/foo_complex.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
You can run this example like this:

    .. code:: console

            $ rm -rf '/tmp/bar'
            $ luigi --module examples.foo_complex examples.Foo --workers 2 --local-scheduler

"""

import random
import time

import luigi

# Divisor applied to the fan-out counts computed in Foo.requires and Bar.requires.
max_depth = 10
# Bar.requires stops adding new dependencies once current_nodes reaches this value.
max_total_nodes = 50
# Running count of Bar tasks handed out so far; mutated by the requires() methods.
current_nodes = 0


class Foo(luigi.Task):
    """Root of the random dependency tree: requires ``30 // max_depth`` initial ``Bar`` tasks."""

    task_namespace = "examples"

    def run(self):
        print("Running Foo")

    def requires(self):
        global current_nodes
        fan_out = 30 // max_depth
        for idx in range(fan_out):
            # Count every node handed out so Bar.requires knows when to stop.
            current_nodes += 1
            yield Bar(idx)


class Bar(luigi.Task):
    """
    Node of a randomly shaped dependency tree.

    Each instance requires a random number of further ``Bar`` tasks, unless the
    global ``current_nodes`` counter has already reached ``max_total_nodes``.
    """

    task_namespace = "examples"

    num = luigi.IntParameter()

    def run(self):
        time.sleep(1)
        self.output().open("w").close()

    def requires(self):
        global current_nodes

        # Guard clause: the node budget is spent, so require nothing further.
        if current_nodes >= max_total_nodes:
            return
        children = int(random.uniform(1, 30)) // max_depth
        for _ in range(children):
            current_nodes += 1
            yield Bar(current_nodes)

    def output(self):
        """
        Returns the target output for this task.

        :return: the target output for this task.
        :rtype: object (:py:class:`~luigi.target.Target`)
        """
        time.sleep(1)
        return luigi.LocalTarget("/tmp/bar/%d" % self.num)


================================================
FILE: examples/ftp_experiment_outputs.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import luigi
from luigi.contrib.ftp import RemoteTarget

# NOTE: placeholder connection settings -- replace them with real values (and
# avoid committing real credentials to source control) before running.
#: the FTP server
HOST = "some_host"
#: the username
USER = "user"
#: the password
PWD = "some_password"


class ExperimentTask(luigi.ExternalTask):
    """
    Stands in for experiment data that lives on an FTP server.

    Conceptually the data is produced by an external process, hence the
    :py:class:`luigi.ExternalTask` base class. This example nevertheless defines
    ``run`` so that it can write some sample data to the server.
    NOTE(review): defining ``run`` on an ExternalTask subclass presumably makes
    luigi execute it -- confirm this is the intended behavior.
    """

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task creates a file on the FTP server.

        :return: the target output for this task.
        :rtype: object (:py:class:`~luigi.target.Target`)
        """
        return RemoteTarget("/experiment/output1.txt", HOST, username=USER, password=PWD)

    def run(self):
        """
        Writes four lines of sample data to this task's output target.
        """
        rows = (
            "data 0 200 10 50 60",
            "data 1 190 9 52 60",
            "data 2 200 10 52 60",
            "data 3 195 1 52 60",
        )
        with self.output().open("w") as outfile:
            for row in rows:
                print(row, file=outfile)


class ProcessingTask(luigi.Task):
    """
    Aggregates the experiment data produced by :py:class:`~.ExperimentTask`.

    Each input line is expected to hold whitespace-separated fields. The task
    writes two values to a local file: the average of the field at index 2 and
    the sum of the field at index 3.
    """

    def requires(self):
        """
        This task's dependencies:

        * :py:class:`~.ExperimentTask`

        :return: object (:py:class:`luigi.task.Task`)
        """
        return ExperimentTask()

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file on the local filesystem.

        :return: the target output for this task.
        :rtype: object (:py:class:`~luigi.target.Target`)
        """
        return luigi.LocalTarget("/tmp/processeddata.txt")

    def run(self):
        """
        Compute the average of field 2 and the sum of field 3 over all input lines.

        Blank lines are ignored.

        :raises ValueError: if the input contains no data lines.
        """
        avg = 0.0
        elements = 0
        sumval = 0.0

        # Target objects are a file system/format abstraction and this will return a file stream object
        # NOTE: self.input() actually returns the ExperimentTask.output() target
        # The context manager guarantees the stream is closed even if parsing fails.
        with self.input().open("r") as infile:
            for line in infile:
                values = line.split()
                if not values:
                    # Skip blank lines (e.g. a trailing empty line).
                    continue
                avg += float(values[2])
                sumval += float(values[3])
                elements += 1

        if elements == 0:
            raise ValueError("ProcessingTask input contains no data lines")

        # average
        avg = avg / elements

        # save calculated values
        with self.output().open("w") as outfile:
            print(avg, sumval, file=outfile)


if __name__ == "__main__":
    # Hand control to luigi's command-line entry point; the task to run and
    # its parameters presumably come from sys.argv.
    luigi.run()


================================================
FILE: examples/hello_world.py
================================================
"""
You can run this example like this:

    .. code:: console

            $ luigi --module examples.hello_world examples.HelloWorldTask --local-scheduler

If that does not work, see :ref:`CommandLine`.
"""

import luigi


class HelloWorldTask(luigi.Task):
    """Minimal task: prints a greeting that includes its own class name."""

    task_namespace = "examples"

    def run(self):
        name = type(self).__name__
        print("{task} says: Hello world!".format(task=name))


if __name__ == "__main__":
    # Same as the luigi command shown in the module docstring, with a single worker.
    luigi.run(["examples.HelloWorldTask", "--workers", "1", "--local-scheduler"])


================================================
FILE: examples/kubernetes.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2015 Outlier Bio, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Example Kubernetes Job Task.

Requires:

- pykube: ``pip install pykube-ng``
- A local minikube cluster up and running: http://kubernetes.io/docs/getting-started-guides/minikube/

**WARNING**: For Python versions < 3.5 the kubeconfig file must point to a Kubernetes API
hostname, and NOT to an IP address.

You can run this code example like this:

    .. code:: console

        $ luigi --module examples.kubernetes PerlPi --local-scheduler

Running this code will create a pi-luigi-uuid kubernetes job within the cluster
pointed to by the default context in "~/.kube/config".

If running within a kubernetes cluster, set auth_method = "service-account" to
access the local cluster.
"""

# import os
# import luigi
from luigi.contrib.kubernetes import KubernetesJobTask


class PerlPi(KubernetesJobTask):
    """
    Kubernetes job that prints pi to 2000 digits using Perl's ``bignum`` ``bpi``.
    """

    # Job name; the module docstring describes the created job as pi-luigi-uuid.
    name = "pi"
    # NOTE(review): presumably the number of retries before the job is treated
    # as failed -- confirm against KubernetesJobTask.
    max_retrials = 3
    # Pod spec: one container from the "perl" image running the bpi(2000) one-liner.
    spec_schema = {"containers": [{"name": "pi", "image": "perl", "command": ["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"]}]}

    # defining the two functions below allows for dependency checking,
    # but isn't a requirement
    # def signal_complete(self):
    #     with self.output().open('w') as output:
    #         output.write('')
    #
    # def output(self):
    #     target = os.path.join("/tmp", "PerlPi")
    #     return luigi.LocalTarget(target)


================================================
FILE: examples/per_task_retry_policy.py
================================================
# -*- coding: utf-8 -*-

"""
You can run this example like this:

    .. code:: console

            $ luigi --module examples.per_task_retry_policy examples.PerTaskRetryPolicy --worker-keep-alive \
            --local-scheduler --scheduler-retry-delay 5  --logging-conf-file test/testconfig/logging.cfg

            ...
            ... lots of spammy output
            ...
            DEBUG: ErrorTask1__99914b932b task num failures is 1 and limit is 5
            DEBUG: ErrorTask2__99914b932b task num failures is 1 and limit is 2
            DEBUG: DynamicErrorTask1__99914b932b task num failures is 1 and limit is 3
            DEBUG: ErrorTask1__99914b932b task num failures is 2 and limit is 5
            DEBUG: ErrorTask2__99914b932b task num failures is 2 and limit is 2
            DEBUG: ErrorTask2__99914b932b task num failures limit(2) is exceeded
            DEBUG: DynamicErrorTask1__99914b932b task num failures is 2 and limit is 3
            DEBUG: ErrorTask1__99914b932b task num failures is 3 and limit is 5
            DEBUG: DynamicErrorTask1__99914b932b task num failures is 3 and limit is 3
            DEBUG: DynamicErrorTask1__99914b932b task num failures limit(3) is exceeded
            DEBUG: ErrorTask1__99914b932b task num failures is 4 and limit is 5
            DEBUG: ErrorTask1__99914b932b task num failures is 5 and limit is 5
            DEBUG: ErrorTask1__99914b932b task num failures limit(5) is exceeded
            INFO:
            ===== Luigi Execution Summary =====

            Scheduled 8 tasks of which:
            * 2 ran successfully:
                - 1 SuccessSubTask1()
                - 1 SuccessTask1()
            * 3 failed:
                - 1 DynamicErrorTask1()
                - 1 ErrorTask1()
                - 1 ErrorTask2()
            * 3 were left pending, among these:
                * 1 were missing external dependencies:
                    - 1 DynamicErrorTaskSubmitter()
                * 1 had failed dependencies:
                    - 1 examples.PerTaskRetryPolicy()
                * 1 had missing dependencies:
                    - 1 examples.PerTaskRetryPolicy()
                * 1 was not granted run permission by the scheduler:
                    - 1 DynamicErrorTaskSubmitter()

            This progress looks :( because there were failed tasks

            ===== Luigi Execution Summary =====
"""

import luigi


class PerTaskRetryPolicy(luigi.Task):
    """
    Groups several failing and succeeding tasks to demonstrate per-task retry limits.

    When run with ``--worker-keep-alive``, the worker is not shut down while
    tasks are still pending, including failed tasks that will be retried.
    """

    task_namespace = "examples"

    def requires(self):
        children = [ErrorTask1(), ErrorTask2(), SuccessTask1(), DynamicErrorTaskSubmitter()]
        return children

    def output(self):
        target_path = "/tmp/_docs-%s.ldj" % self.task_id
        return luigi.LocalTarget(path=target_path)


class ErrorTask1(luigi.Task):
    """
    Task that always raises, so the scheduler keeps retrying it.

    ``retry_count = 5`` is this task's failure limit. The DEBUG log in the module
    docstring shows ErrorTask1 failing 5 times before its limit is reported as
    exceeded. ``retry`` counts how many times :py:meth:`run` has been called on
    this instance, and that count is included in the exception message.
    """

    retry = 0

    retry_count = 5

    def run(self):
        self.retry += 1
        raise Exception("Test Exception. Retry Index %s for %s" % (self.retry, self.task_family))

    def output(self):
        return luigi.LocalTarget(path="/tmp/_docs-%s.ldj" % self.task_id)


class ErrorTask2(luigi.Task):
    """
    Task that always raises, so the scheduler keeps retrying it.

    ``retry_count = 2`` is this task's failure limit. ``retry`` counts the calls
    to :py:meth:`run` on this instance and is reported in the exception message.
    """

    retry_count = 2

    retry = 0

    def output(self):
        return luigi.LocalTarget(path="/tmp/_docs-%s.ldj" % self.task_id)

    def run(self):
        self.retry = self.retry + 1
        message = "Test Exception. Retry Index %s for %s" % (self.retry, self.task_family)
        raise Exception(message)


class DynamicErrorTaskSubmitter(luigi.Task):
    """
    Yields :py:class:`DynamicErrorTask1` as a dynamic dependency from
    :py:meth:`run`. It writes its own output only if that dependency's output exists.
    """

    target = None

    def output(self):
        return luigi.LocalTarget(path="/tmp/_docs-%s.ldj" % self.task_id)

    def run(self):
        dynamic_output = yield DynamicErrorTask1()
        if not dynamic_output.exists():
            return
        with self.output().open("w") as sink:
            sink.write("SUCCESS DynamicErrorTaskSubmitter\n")


class DynamicErrorTask1(luigi.Task):
    """
    Dynamically required task that always raises, so it is retried.

    ``retry_count = 3`` is its failure limit. ``retry`` counts the calls to
    :py:meth:`run` on this instance and appears in the exception message.
    """

    retry_count = 3

    retry = 0

    def output(self):
        return luigi.LocalTarget(path="/tmp/_docs-%s.ldj" % self.task_id)

    def run(self):
        self.retry = self.retry + 1
        message = "Test Exception. Retry Index %s for %s" % (self.retry, self.task_family)
        raise Exception(message)


class SuccessTask1(luigi.Task):
    """
    Task that succeeds by writing a single line to its output. It requires
    :py:class:`SuccessSubTask1`.
    """

    def output(self):
        return luigi.LocalTarget(path="/tmp/_docs-%s.ldj" % self.task_id)

    def requires(self):
        return [SuccessSubTask1()]

    def run(self):
        with self.output().open("w") as sink:
            sink.write("SUCCESS Test Task 4\n")


class SuccessSubTask1(luigi.Task):
    """
    Leaf task that writes one line to its output and completes successfully.
    It has no dependencies and does not sleep.
    """

    def run(self):
        with self.output().open("w") as output:
            output.write("SUCCESS Test Task 4.1\n")

    def output(self):
        return luigi.LocalTarget(path="/tmp/_docs-%s.ldj" % self.task_id)


================================================
FILE: examples/pyspark_wc.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import luigi
from luigi.contrib.s3 import S3Target
from luigi.contrib.spark import PySparkTask, SparkSubmitTask


class InlinePySparkWordCount(PySparkTask):
    """
    Word count implemented inline as a :py:class:`luigi.contrib.spark.PySparkTask`.

    The job reads the S3 file ``wordcount.input`` and counts occurrences of each
    whitespace-separated word. It saves the ``(word, count)`` pairs to the S3
    location ``wordcount.output``. The Spark logic lives in the override of
    :py:meth:`luigi.contrib.spark.PySparkTask.main` below.

    Example luigi configuration::

        [spark]
        spark-submit: /usr/local/spark/bin/spark-submit
        master: spark://spark.example.org:7077
        # py-packages: numpy, pandas

    """

    executor_memory = "3g"
    driver_memory = "2g"

    def output(self):
        return S3Target("s3n://bucket.example.org/wordcount.output")

    def input(self):
        return S3Target("s3n://bucket.example.org/wordcount.input")

    def main(self, sc, *args):
        lines = sc.textFile(self.input().path)
        pairs = lines.flatMap(lambda line: line.split()).map(lambda word: (word, 1))
        totals = pairs.reduceByKey(lambda left, right: left + right)
        totals.saveAsTextFile(self.output().path)


class PySparkWordCount(SparkSubmitTask):
    """
    Same word count as :py:class:`InlinePySparkWordCount`, but the Spark code
    lives in the external driver script named by :py:attr:`app`.

    The script is submitted through :py:meth:`luigi.contrib.spark.SparkSubmitTask.run`.
    It reads the S3 file ``wordcount.input`` and writes to ``wordcount.output``.
    Both paths are handed to the driver script as positional arguments.

    Example luigi configuration::

        [spark]
        spark-submit: /usr/local/spark/bin/spark-submit
        master: spark://spark.example.org:7077
        deploy-mode: client

    """

    name = "PySpark Word Count"
    app = "wordcount.py"

    executor_memory = "3g"
    driver_memory = "2g"
    total_executor_cores = luigi.IntParameter(default=100, significant=False)

    def input(self):
        return S3Target("s3n://bucket.example.org/wordcount.input")

    def output(self):
        return S3Target("s3n://bucket.example.org/wordcount.output")

    def app_options(self):
        # Positional arguments for the driver script: the input path first,
        # then the output path.
        source = self.input().path
        destination = self.output().path
        return [source, destination]


# Reference driver script for PySparkWordCount (app = "wordcount.py"). It is kept
# here as an unused string literal and is never executed by this module.
"""
// Corresponding example Spark Job, running Word count with Spark's Python API
// This file would have to be saved into wordcount.py

import sys
from pyspark import SparkContext

if __name__ == "__main__":

    sc = SparkContext()
    sc.textFile(sys.argv[1]) \
      .flatMap(lambda line: line.split()) \
      .map(lambda word: (word, 1)) \
      .reduceByKey(lambda a, b: a + b) \
      .saveAsTextFile(sys.argv[2])
"""


================================================
FILE: examples/spark_als.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import random

import luigi
import luigi.contrib.hdfs
import luigi.format
from luigi.contrib.spark import SparkSubmitTask


class UserItemMatrix(luigi.Task):
    """
    Generates a synthetic user/track/rating matrix used as input by :py:class:`SparkALS`.
    """

    #: the size of the data being generated
    data_size = luigi.IntParameter()

    def run(self):
        """
        Generates :py:attr:`~.UserItemMatrix.data_size` elements.
        Writes this data into the target :py:func:`~/.UserItemMatrix.output`
        in tab-separated format, one record per line.

        Each line has the following fields:

        * `user`: the user id, from 0 to ``data_size - 1``,
        * `track`: a random track id in ``[0, data_size)``,
        * `rating`: the rating, always ``1.0``.

        """
        # ``with`` closes the writer even if generation fails part-way through.
        with self.output().open("w") as w:
            for user in range(self.data_size):
                track = int(random.random() * self.data_size)
                # Tab-separated and newline-terminated, because the ImplicitALS
                # job reads lines and splits each one with ``l.split('\t')``.
                w.write("%d\t%d\t%f\n" % (user, track, 1.0))

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file in HDFS.

        :return: the target output for this task.
        :rtype: object (:py:class:`~luigi.target.Target`)
        """
        return luigi.contrib.hdfs.HdfsTarget("data-matrix", format=luigi.format.Gzip)


class SparkALS(SparkSubmitTask):
    """
    Submits the ``com.spotify.spark.ImplicitALS`` main class from
    ``my-spark-assembly.jar`` via :py:meth:`luigi.contrib.spark.SparkSubmitTask.run`.

    The job reads the matrix produced by :py:meth:`~/.UserItemMatrix.output`.
    It writes its result into :py:meth:`~.SparkALS.output`, a Gzip-formatted
    target in HDFS.

    Example luigi configuration::

        [spark]
        spark-submit: /usr/local/spark/bin/spark-submit
        master: yarn-client

    """

    data_size = luigi.IntParameter(default=1000)
    num_executors = luigi.IntParameter(default=100)

    driver_memory = "2g"
    executor_memory = "3g"

    app = "my-spark-assembly.jar"
    entry_class = "com.spotify.spark.ImplicitALS"

    def requires(self):
        """
        Depends on :py:class:`~.UserItemMatrix` built with the same ``data_size``.

        :return: object (:py:class:`luigi.task.Task`)
        """
        return UserItemMatrix(self.data_size)

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file in HDFS.

        :return: the target output for this task.
        :rtype: object (:py:class:`~luigi.target.Target`)
        """
        # The Spark job saves its output with the Gzip codec.
        return luigi.contrib.hdfs.HdfsTarget("als-output/", format=luigi.format.Gzip)

    def app_options(self):
        # Handed to the Spark main class as positional arguments, in this order.
        matrix_path = self.input().path
        result_path = self.output().path
        return [matrix_path, result_path]


# Reference Scala job for SparkALS, kept as an unused string literal; it is
# never executed by this module.
"""
// Corresponding example Spark Job, a wrapper around the MLLib ALS job.
// This class would have to be jarred into my-spark-assembly.jar
// using sbt assembly (or package) and made available to the Luigi job
// above.

package com.spotify.spark

import org.apache.spark._
import org.apache.spark.mllib.recommendation.{Rating, ALS}
import org.apache.hadoop.io.compress.GzipCodec

object ImplicitALS {

  def main(args: Array[String]) {
    // spark-submit supplies the master URL, so args only carries the two
    // paths returned by SparkALS.app_options: input first, then output.
    val sc = new SparkContext(new SparkConf().setAppName("ImplicitALS"))
    val input = args(0)
    val output = args(1)

    val ratings = sc.textFile(input)
      .map { l: String =>
        val t = l.split('\t')
        Rating(t(0).toInt, t(1).toInt, t(2).toFloat)
      }

    val model = ALS.trainImplicit(ratings, 40, 20, 0.8, 150)
    model
      .productFeatures
      .map { case (id, vec) =>
        id + "\t" + vec.map(d => "%.6f".format(d)).mkString(" ")
      }
      .saveAsTextFile(output, classOf[GzipCodec])

    sc.stop()
  }
}
"""


================================================
FILE: examples/ssh_remote_execution.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import defaultdict

import luigi
from luigi.contrib.ssh import RemoteContext, RemoteTarget
from luigi.mock import MockTarget

# Host that the example tasks below connect to over SSH. The value looks like a
# placeholder; presumably it must be replaced with a reachable host.
SSH_HOST = "some.accessible.host"


class CreateRemoteData(luigi.Task):
    """
    Runs ``ps aux`` on :py:data:`SSH_HOST` and redirects its output to a file on
    that same host. Nothing is copied back to the local machine.
    """

    def run(self):
        context = RemoteContext(SSH_HOST)
        command = "ps aux > {0}".format(self.output().path)
        print(context.check_output([command]))

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file on a remote server using SSH.

        :return: the target output for this task.
        :rtype: object (:py:class:`~luigi.target.Target`)
        """
        return RemoteTarget("/tmp/stuff", SSH_HOST)


class ProcessRemoteData(luigi.Task):
    """
    Create a toplist of users based on how many running processes they have on a remote machine.

    In this example the processed data is stored in a MockTarget.
    """

    def requires(self):
        """
        This task's dependencies:

        * :py:class:`~.CreateRemoteData`

        :return: object (:py:class:`luigi.task.Task`)
        """
        return CreateRemoteData()

    def run(self):
        processes_per_user = defaultdict(int)
        with self.input().open("r") as infile:
            lines = iter(infile)
            # The first line of ``ps aux`` output is the column header
            # ("USER PID ..."), not a process, so it must not be counted as a user.
            next(lines, None)
            for line in lines:
                fields = line.split()
                if not fields:  # tolerate blank lines instead of raising IndexError
                    continue
                processes_per_user[fields[0]] += 1

        toplist = sorted(processes_per_user.items(), key=lambda x: x[1], reverse=True)

        with self.output().open("w") as outfile:
            for user, n_processes in toplist:
                print(n_processes, user, file=outfile)

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will simulate the creation of a file in a filesystem.

        :return: the target output for this task.
        :rtype: object (:py:class:`~luigi.target.Target`)
        """
        return MockTarget("output", mirror_on_stderr=True)


================================================
FILE: examples/terasort.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import os

import luigi
import luigi.contrib.hadoop_jar
import luigi.contrib.hdfs

# Module logger. It uses the "luigi-interface" logger name rather than __name__.
logger = logging.getLogger("luigi-interface")


def hadoop_examples_jar():
    """
    Return the path of the Hadoop examples jar configured as
    ``[hadoop] examples-jar`` in the luigi configuration.

    :return: the configured jar path.
    :raises RuntimeError: if the option is missing or empty, or if the file does not exist.
    """
    config = luigi.configuration.get_config()
    # An explicit default makes a missing option reach the check below instead of
    # raising a configuration error from ``get`` itself.
    examples_jar = config.get("hadoop", "examples-jar", None)
    if not examples_jar:
        logger.error("You must specify hadoop:examples-jar in luigi.cfg")
        # A bare ``raise`` here has no active exception to re-raise. Raise a real
        # error with a useful message instead (RuntimeError, same type as before).
        raise RuntimeError("You must specify hadoop:examples-jar in luigi.cfg")
    if not os.path.exists(examples_jar):
        logger.error("Can't find example jar: " + examples_jar)
        raise RuntimeError("Can't find example jar: " + examples_jar)
    return examples_jar


# Default HDFS directories. TeraGen writes its generated data to
# DEFAULT_TERASORT_IN, and TeraSort writes its sorted output to DEFAULT_TERASORT_OUT.
DEFAULT_TERASORT_IN = "/tmp/terasort-in"
DEFAULT_TERASORT_OUT = "/tmp/terasort-out"


class TeraGen(luigi.contrib.hadoop_jar.HadoopJarJobTask):
    """
    Runs the ``teragen`` program from the Hadoop examples jar.

    With the default of 10,000,000,000 records of 100 bytes each, this
    generates 1TB of data into :py:attr:`terasort_in`.
    """

    records = luigi.Parameter(default="10000000000", description="Number of records, each record is 100 Bytes")
    terasort_in = luigi.Parameter(default=DEFAULT_TERASORT_IN, description="directory to store terasort input into.")

    def jar(self):
        return hadoop_examples_jar()

    def main(self):
        return "teragen"

    def args(self):
        # teragen <number of records> <output directory>
        record_count = self.records
        return [record_count, self.output()]

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file in HDFS.

        :return: the target output for this task.
        :rtype: object (:py:class:`~luigi.target.Target`)
        """
        return luigi.contrib.hdfs.HdfsTarget(self.terasort_in)


class TeraSort(luigi.contrib.hadoop_jar.HadoopJarJobTask):
    """
    Runs the ``terasort`` program from the Hadoop examples jar. It sorts the data
    that :py:class:`TeraGen` wrote to :py:attr:`terasort_in` and writes the
    result to :py:attr:`terasort_out`.
    """

    terasort_in = luigi.Parameter(default=DEFAULT_TERASORT_IN, description="directory to store terasort input into.")
    terasort_out = luigi.Parameter(default=DEFAULT_TERASORT_OUT, description="directory to store terasort output into.")

    def requires(self):
        """
        This task's dependencies:

        * :py:class:`~.TeraGen`

        :return: object (:py:class:`luigi.task.Task`)
        """
        return TeraGen(terasort_in=self.terasort_in)

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file in HDFS.

        :return: the target output for this task.
        :rtype: object (:py:class:`~luigi.target.Target`)
        """
        return luigi.contrib.hdfs.HdfsTarget(self.terasort_out)

    def jar(self):
        return hadoop_examples_jar()

    def main(self):
        return "terasort"

    def args(self):
        # terasort <input directory> <output directory>
        return [self.input(), self.output()]


================================================
FILE: examples/top_artists.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import random
from collections import defaultdict
from heapq import nlargest

import luigi
import luigi.contrib.hdfs
import luigi.contrib.postgres
import luigi.contrib.spark


class ExternalStreams(luigi.ExternalTask):
    """
    Stand-in for a data dump produced outside of Luigi.

    To depend on targets that some other process creates (typically at the top of
    a dependency graph), declare an ExternalTask like this one.
    """

    date = luigi.DateParameter()

    def output(self):
        """
        Returns the target output for this task.
        In this case, it expects a file to be present in HDFS.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        dump_path = self.date.strftime("data/streams_%Y-%m-%d.tsv")
        return luigi.contrib.hdfs.HdfsTarget(dump_path)


class Streams(luigi.Task):
    """
    Fake source of stream data. Each run writes 1000 lines, and each line holds
    three random integers in ``[0, 999]`` separated by single spaces.
    """

    date = luigi.DateParameter()

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file in the local file system.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        return luigi.LocalTarget(self.date.strftime("data/streams_%Y_%m_%d_faked.tsv"))

    def run(self):
        """
        Generates bogus data and writes it into the :py:meth:`~.Streams.output` target.
        """
        with self.output().open("w") as sink:
            for _row in range(1000):
                fields = [random.randint(0, 999) for _col in range(3)]
                sink.write("{} {} {}\n".format(*fields))


class StreamsHdfs(Streams):
    """
    Same fake data as :py:class:`~.Streams`, but stored in HDFS.

    :py:meth:`~.Streams.run` is inherited unchanged. Only
    :py:meth:`~.Streams.output` is overridden, so that it points at HDFS.
    """

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file in HDFS.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        hdfs_path = self.date.strftime("data/streams_%Y_%m_%d_faked.tsv")
        return luigi.contrib.hdfs.HdfsTarget(hdfs_path)


class AggregateArtists(luigi.Task):
    """
    Counts the lines per artist across the :py:meth:`~/.Streams.output` targets
    of every day in :py:attr:`date_interval`. It writes ``artist<TAB>count`` lines
    into :py:meth:`~.AggregateArtists.output` (a local file).
    """

    date_interval = luigi.DateIntervalParameter()

    def requires(self):
        """
        This task's dependencies:

        * :py:class:`~.Streams`, one per date in the interval

        :return: list of object (:py:class:`luigi.task.Task`)
        """
        return [Streams(day) for day in self.date_interval]

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file on the local filesystem.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        return luigi.LocalTarget("data/artist_streams_{}.tsv".format(self.date_interval))

    def run(self):
        plays_by_artist = defaultdict(int)

        for stream_target in self.input():
            with stream_target.open("r") as stream_file:
                for record in stream_file:
                    # Each record must have exactly three fields; the artist is the second.
                    _first, artist, _track = record.strip().split()
                    plays_by_artist[artist] += 1

        with self.output().open("w") as out_file:
            for artist, total in plays_by_artist.items():
                out_file.write("{}\t{}\n".format(artist, total))


class AggregateArtistsSpark(luigi.contrib.spark.SparkSubmitTask):
    """
    Spark counterpart of :py:class:`~.AggregateArtists`. It submits :py:attr:`app`
    over every target returned by :py:meth:`~/.StreamsHdfs.output` and writes the
    result into :py:meth:`~.AggregateArtistsSpark.output` (a file in HDFS).
    """

    date_interval = luigi.DateIntervalParameter()

    # The PySpark script to run. For Spark applications written in Java or
    # Scala, supply the name of a jar file instead.
    app = "top_artists_spark.py"

    # Address of the Spark cluster master. No cluster is used here; Spark runs
    # in local mode.
    master = "local[*]"

    def requires(self):
        """
        This task's dependencies:

        * :py:class:`~.StreamsHdfs`, one per date in the interval

        :return: list of object (:py:class:`luigi.task.Task`)
        """
        return [StreamsHdfs(date=day) for day in self.date_interval]

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file in HDFS.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        return luigi.contrib.hdfs.HdfsTarget("data/artist_streams_%s.tsv" % self.date_interval)

    def app_options(self):
        # input() holds one target per StreamsHdfs dependency. The script receives
        # their paths joined by commas, followed by the output path.
        joined_inputs = ",".join(target.path for target in self.input())
        return [joined_inputs, self.output().path]


class Top10Artists(luigi.Task):
    """
    Picks the ten artists with the most streams from the per-artist counts.

    The counts come from :py:meth:`~/.AggregateArtists.output`, or from
    :py:meth:`~/.AggregateArtistsSpark.output` when :py:attr:`~/.Top10Artists.use_spark`
    is set. The result goes to :py:meth:`~.Top10Artists.output` (a file in the
    local filesystem).
    """

    date_interval = luigi.DateIntervalParameter()
    use_spark = luigi.BoolParameter()

    def requires(self):
        """
        This task's dependencies:

        * :py:class:`~.AggregateArtists` or
        * :py:class:`~.AggregateArtistsSpark` if :py:attr:`~/.Top10Artists.use_spark` is set.

        :return: object (:py:class:`luigi.task.Task`)
        """
        aggregator = AggregateArtistsSpark if self.use_spark else AggregateArtists
        return aggregator(self.date_interval)

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file on the local filesystem.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        return luigi.LocalTarget("data/top_artists_%s.tsv" % self.date_interval)

    def run(self):
        top_10 = nlargest(10, self._input_iterator())
        date_from = str(self.date_interval.date_a)
        date_to = str(self.date_interval.date_b)
        with self.output().open("w") as out_file:
            for streams, artist in top_10:
                columns = [date_from, date_to, artist, str(streams)]
                out_file.write("\t".join(columns) + "\n")

    def _input_iterator(self):
        # Yields (streams, artist) tuples so that nlargest ranks by stream count.
        with self.input().open("r") as in_file:
            for line in in_file:
                artist, streams = line.strip().split()
                yield int(streams), artist


class ArtistToplistToDatabase(luigi.contrib.postgres.CopyToTable):
    """
    Loads the rows produced by :py:meth:`~/.Top10Artists.output` into the
    PostgreSQL table ``top10``.

    Both the copy itself (:py:meth:`luigi.contrib.postgres.CopyToTable.run`) and
    the default :py:class:`luigi.contrib.postgres.PostgresTarget` output
    (:py:meth:`luigi.contrib.postgres.CopyToTable.output`) are inherited from
    :py:class:`luigi.contrib.postgres.CopyToTable`.
    """

    date_interval = luigi.DateIntervalParameter()
    use_spark = luigi.BoolParameter()

    # Connection settings for the example database (hard-coded for the demo).
    host = "localhost"
    database = "toplists"
    user = "luigi"
    password = "abc123"  # ;)
    table = "top10"

    # (name, SQL type) pairs, in the same order as the fields of Top10Artists' output lines.
    columns = [
        ("date_from", "DATE"),
        ("date_to", "DATE"),
        ("artist", "TEXT"),
        ("streams", "INT"),
    ]

    def requires(self):
        """
        This task's dependencies:

        * :py:class:`~.Top10Artists`

        :return: object (:py:class:`luigi.task.Task`)
        """
        return Top10Artists(date_interval=self.date_interval, use_spark=self.use_spark)


# When run as a script, hand control to luigi's command-line entry point.
if __name__ == "__main__":
    luigi.run()


================================================
FILE: examples/top_artists_spark.py
================================================
# -*- coding: utf-8 -*-

import operator
import sys

from pyspark.sql import SparkSession


def main(argv):
    """
    Count streams per artist for the Spark variant of the top-artists example.

    :param argv: command-line arguments. ``argv[1]`` is a comma-separated list of
        input paths and ``argv[2]`` is the output path; this matches what
        ``AggregateArtistsSpark.app_options`` in top_artists.py passes.

    Each input line holds whitespace-separated fields, and the second field is
    the artist. The result is written as tab-separated ``artist, count`` rows.
    """
    input_paths = argv[1].split(",")
    output_path = argv[2]

    spark = SparkSession.builder.getOrCreate()

    # ``read.text`` accepts a list of paths, so every input ends up in one
    # DataFrame with a single ``value`` column holding each raw line. Note that
    # ``DataFrame.union`` returns a new DataFrame, so calling it without keeping
    # the result would drop every input except the first.
    lines = spark.read.text(input_paths)

    # Split on arbitrary whitespace, as AggregateArtists does locally. The faked
    # stream files are space-separated even though their names end in ``.tsv``.
    # DataFrames have no ``map``/``reduceByKey``, so drop to the RDD API here.
    counts = lines.rdd.map(lambda row: (row.value.split()[1], 1)).reduceByKey(operator.add)

    # RDDs have no ``write`` attribute; convert back to a DataFrame to save as CSV.
    counts.toDF(["artist", "count"]).write.option("sep", "\t").csv(output_path)


# main() returns None when it finishes, so sys.exit() reports exit status 0.
if __name__ == "__main__":
    sys.exit(main(sys.argv))


================================================
FILE: examples/wordcount.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import luigi


class InputText(luigi.ExternalTask):
    """
    One day's text file, produced by some process outside of Luigi.

    Because the task only describes data that is created elsewhere, implementing
    :py:meth:`output` is all it needs.
    """

    date = luigi.DateParameter()

    def output(self):
        """
        Returns the target output for this task.
        In this case, it expects a file to be present in the local file system.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        daily_path = self.date.strftime("/var/tmp/text/%Y-%m-%d.txt")
        return luigi.LocalTarget(daily_path)


class WordCount(luigi.Task):
    """
    Counts whitespace-separated words across the :py:class:`InputText` files for
    every date in :py:attr:`date_interval`. Writes ``word<TAB>count`` lines.
    """

    date_interval = luigi.DateIntervalParameter()

    def requires(self):
        """
        This task's dependencies:

        * :py:class:`~.InputText`

        :return: list of object (:py:class:`luigi.task.Task`)
        """
        return [InputText(date) for date in self.date_interval.dates()]

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file on the local filesystem.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        return luigi.LocalTarget("/var/tmp/text-count/%s" % self.date_interval)

    def run(self):
        """
        1. count the words for each of the :py:meth:`~.InputText.output` targets created by :py:class:`~.InputText`
        2. write the count into the :py:meth:`~.WordCount.output` target
        """
        word_counts = {}

        # NOTE: self.input() is a wrapper around requires() that returns Target
        # objects: one InputText.output() target per date here.
        for target in self.input():
            # Target objects are a file system/format abstraction; open() returns a
            # file stream object, which ``with`` closes once it has been read.
            with target.open("r") as in_file:
                for line in in_file:
                    for word in line.strip().split():
                        word_counts[word] = word_counts.get(word, 0) + 1

        # Output writes are atomic: the data only appears at the output path once
        # the file is closed. ``with`` guarantees that close happens.
        with self.output().open("w") as out_file:
            for word, occurrences in word_counts.items():
                out_file.write("%s\t%d\n" % (word, occurrences))


================================================
FILE: examples/wordcount_hadoop.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import luigi
import luigi.contrib.hadoop
import luigi.contrib.hdfs

# To make this run, you probably want to edit your luigi configuration (e.g. /etc/luigi/luigi.cfg, or the legacy /etc/luigi/client.cfg) and add something like:
#
# [hadoop]
# jar: /usr/lib/hadoop-xyz/hadoop-streaming-xyz-123.jar


class InputText(luigi.ExternalTask):
    """
    A :py:class:`luigi.task.ExternalTask`: it does not generate its
    :py:meth:`~.InputText.output` target itself. Instead, it relies on something
    outside of Luigi to produce that target.
    """

    date = luigi.DateParameter()

    def output(self):
        """
        Returns the target output for this task.
        In this case, it expects a file to be present in HDFS.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        daily_path = self.date.strftime("/tmp/text/%Y-%m-%d.txt")
        return luigi.contrib.hdfs.HdfsTarget(daily_path)


class WordCount(luigi.contrib.hadoop.JobTask):
    """
    Hadoop streaming word count over the :py:meth:`~/.InputText.output` targets of
    every date in :py:attr:`date_interval`. The result is written into
    :py:meth:`~.WordCount.output`.

    Job submission is handled by :py:meth:`luigi.contrib.hadoop.JobTask.run`.
    """

    date_interval = luigi.DateIntervalParameter()

    def requires(self):
        """
        This task's dependencies:

        * :py:class:`~.InputText`, one per date in the interval

        :return: list of object (:py:class:`luigi.task.Task`)
        """
        return [InputText(day) for day in self.date_interval.dates()]

    def output(self):
        """
        Returns the target output for this task.
        In this case, a successful execution of this task will create a file in HDFS.

        :return: the target output for this task.
        :rtype: object (:py:class:`luigi.target.Target`)
        """
        return luigi.contrib.hdfs.HdfsTarget("/tmp/text-count/%s" % self.date_interval)

    def mapper(self, line):
        # Emit a (word, 1) pair for every whitespace-separated token on the line.
        yield from ((token, 1) for token in line.strip().split())

    def reducer(self, key, values):
        total = sum(values)
        yield key, total


if __name__ == "__main__":
    # Let luigi parse the task name and parameters from the command line and run it.
    luigi.run()


================================================
FILE: luigi/__init__.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Package containing core luigi functionality.
"""

from luigi import configuration, event, interface, local_target, parameter, rpc, target, task
from luigi.__version__ import VERSION
from luigi.event import Event
from luigi.execution_summary import LuigiStatusCode
from luigi.interface import build, run
from luigi.local_target import LocalTarget
from luigi.parameter import (
    BoolParameter,
    ChoiceListParameter,
    ChoiceParameter,
    DateHourParameter,
    DateIntervalParameter,
    DateMinuteParameter,
    DateParameter,
    DateSecondParameter,
    DictParameter,
    EnumListParameter,
    EnumParameter,
    FloatParameter,
    IntParameter,
    ListParameter,
    MonthParameter,
    NumericalParameter,
    OptionalBoolParameter,
    OptionalChoiceParameter,
    OptionalDictParameter,
    OptionalFloatParameter,
    OptionalIntParameter,
    OptionalListParameter,
    OptionalNumericalParameter,
    OptionalParameter,
    OptionalPathParameter,
    OptionalStrParameter,
    OptionalTupleParameter,
    Parameter,
    PathParameter,
    StrParameter,
    TaskParameter,
    TimeDeltaParameter,
    TupleParameter,
    YearParameter,
)
from luigi.rpc import RemoteScheduler, RPCError
from luigi.target import Target
from luigi.task import (
    Config,
    DynamicRequirements,
    ExternalTask,
    Task,
    WrapperTask,
    auto_namespace,
    namespace,
)

# Re-export the version string so callers can read ``luigi.__version__``.
__version__ = VERSION
# Names exported by ``from luigi import *``. "range" is appended at the bottom
# of this module when range tasks are autoloaded.
__all__ = [
    "task",
    "Task",
    "Config",
    "ExternalTask",
    "WrapperTask",
    "namespace",
    "auto_namespace",
    "DynamicRequirements",
    "target",
    "Target",
    "LocalTarget",
    "rpc",
    "RemoteScheduler",
    "RPCError",
    "parameter",
    "Parameter",
    "DateParameter",
    "MonthParameter",
    "YearParameter",
    "DateHourParameter",
    "DateMinuteParameter",
    "DateSecondParameter",
    "DateIntervalParameter",
    "TimeDeltaParameter",
    "StrParameter",
    "IntParameter",
    "FloatParameter",
    "BoolParameter",
    "PathParameter",
    "TaskParameter",
    "ListParameter",
    "TupleParameter",
    "EnumParameter",
    "DictParameter",
    "EnumListParameter",
    "configuration",
    "interface",
    "local_target",
    "run",
    "build",
    "event",
    "Event",
    "NumericalParameter",
    "ChoiceParameter",
    "ChoiceListParameter",
    "OptionalParameter",
    "OptionalStrParameter",
    "OptionalIntParameter",
    "OptionalFloatParameter",
    "OptionalBoolParameter",
    "OptionalPathParameter",
    "OptionalDictParameter",
    "OptionalListParameter",
    "OptionalTupleParameter",
    "OptionalChoiceParameter",
    "OptionalNumericalParameter",
    "LuigiStatusCode",
    "__version__",
]

# Everything below runs at import time against the active configuration.
# When [core] autoload_range is not set at all, warn that autoloading range
# tasks by default is deprecated.
# NOTE(review): a DeprecationWarning raised from inside a library import is
# hidden by Python's default warning filters, so most users will not see this
# message unless they enable warnings (e.g. ``python -W default``).
if not configuration.get_config().has_option("core", "autoload_range"):
    import warnings

    warning_message = """
        Autoloading range tasks by default has been deprecated and will be removed in a future version.
        To get the behavior now add an option to luigi.cfg:

          [core]
            autoload_range: false

        Alternately set the option to true to continue with existing behaviour and suppress this warning.
    """
    warnings.warn(warning_message, DeprecationWarning)

# Autoload range tasks unless autoload_range is explicitly false; a missing
# option falls back to the default argument True.
if configuration.get_config().getboolean("core", "autoload_range", True):
    from .tools import range  # noqa: F401    just makes the tool classes available from command line

    __all__.append("range")


================================================
FILE: luigi/__main__.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2016 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from luigi.cmdline import luigi_run

if __name__ == "__main__":
    # Reached via ``python -m luigi``: hand the command line over to luigi_run.
    luigi_run()


================================================
FILE: luigi/__version__.py
================================================
# coding: utf-8

# Package version string; luigi/__init__.py re-exports it as ``luigi.__version__``.
VERSION = "3.8.0"


================================================
FILE: luigi/batch_notifier.py
================================================
"""
Library for sending batch notifications from the Luigi scheduler. This module
is internal to Luigi and not designed for use in other contexts.
"""

import collections
import time
from datetime import datetime

import luigi
import luigi.parameter
import luigi.task
from luigi.notifications import email, send_email


class batch_email(luigi.task.Config):
    """
    Settings for the batched failure e-mails assembled by ``BatchNotifier``.
    """

    # The config_path entry names an additional config location for this value.
    email_interval = luigi.parameter.IntParameter(
        default=60,
        config_path={"section": "batch-notifier", "name": "email-interval-minutes"},
        description="Number of minutes between e-mail sends (default: 60)",
    )
    batch_mode = luigi.parameter.ChoiceParameter(
        choices=("family", "all", "unbatched_params"),
        default="unbatched_params",
        description=(
            'Method used for batching failures in e-mail. If "family" all failures for '
            'tasks with the same family will be batched. If "unbatched_params", all '
            "failures for tasks with the same family and non-batched parameters will be "
            'batched. If "all", tasks will only be batched if they have identical names.'
        ),
    )
    error_lines = luigi.parameter.IntParameter(
        default=20,
        description="Number of lines to show from each error message. 0 means show all",
    )
    error_messages = luigi.parameter.IntParameter(
        default=1,
        description="Number of error messages to show for each group",
    )
    group_by_error_messages = luigi.parameter.BoolParameter(
        default=True,
        description="Group items with the same error messages together",
    )


class ExplQueue(collections.OrderedDict):
    """
    Bounded, insertion-ordered record of recently enqueued items.

    Each key maps to the time it was last enqueued. Re-enqueueing a key
    refreshes its timestamp and moves it to the newest position; when more
    than ``num_items`` keys are held, the oldest one is evicted.
    """

    def __init__(self, num_items):
        super().__init__()
        self.num_items = num_items

    def enqueue(self, item):
        """Mark ``item`` as the newest entry, dropping the oldest if over capacity."""
        self[item] = datetime.now()
        self.move_to_end(item)
        over_capacity = len(self) > self.num_items
        if over_capacity:
            self.popitem(last=False)


def _fail_queue(num_messages):
    return lambda: collections.defaultdict(lambda: ExplQueue(num_messages))


def _plural_format(template, number, plural="s"):
    if number == 0:
        return ""
    return template.format(number, "" if number == 1 else plural)


class BatchNotifier:
    def __init__(self, **kwargs):
        self._config = batch_email(**kwargs)
        self._fail_counts = collections.defaultdict(collections.Counter)
        self._disabled_counts = collections.defaultdict(collections.Counter)
        self._scheduling_fail_counts = collections.defaultdict(collections.Counter)
        self._fail_expls = collections.defaultdict(_fail_queue(self._config.error_messages))
        self._update_next_send()

        self._email_format = email().format
        if email().receiver:
            self._default_owner = set(filter(None, email().receiver.split(",")))
        else:
            self._default_owner = set()

    def _update_next_send(self):
        self._next_send = time.time() + 60 * self._config.email_interval

    def _key(self, task_name, family, unbatched_args):
        if self._config.batch_mode == "all":
            return task_name
        elif self._config.batch_mode == "family":
            return family
        elif self._config.batch_mode == "unbatched_params":
            param_str = ", ".join("{}={}".format(k, v) for k, v in unbatched_args.items())
            return "{}({})".format(family, param_str)
        else:
            raise ValueError("Unknown batch mode for batch notifier: {}".format(self._config.batch_mode))

    def _format_expl(self, expl):
        lines = expl.rstrip().split("\n")[-self._config.error_lines :]
        if self._email_format == "html":
            return "
{}
".format("\n".join(lines)) else: return "\n{}".format("\n".join(map(" {}".format, lines))) def _expl_body(self, expls): lines = [self._format_expl(expl) for expl in expls] if lines and self._email_format != "html": lines.append("") return "\n".join(lines) def _format_task(self, task_tuple): task, failure_count, disable_count, scheduling_count = task_tuple counts = [ _plural_format("{} failure{}", failure_count), _plural_format("{} disable{}", disable_count), _plural_format("{} scheduling failure{}", scheduling_count), ] count_str = ", ".join(filter(None, counts)) return "{} ({})".format(task, count_str) def _format_tasks(self, tasks): lines = map(self._format_task, sorted(tasks, key=self._expl_key)) if self._email_format == "html": return "
  • {}".format("\n
    ".join(lines)) else: return "- {}".format("\n ".join(lines)) def _owners(self, owners): return self._default_owner | set(owners) def add_failure(self, task_name, family, unbatched_args, expl, owners): key = self._key(task_name, family, unbatched_args) for owner in self._owners(owners): self._fail_counts[owner][key] += 1 self._fail_expls[owner][key].enqueue(expl) def add_disable(self, task_name, family, unbatched_args, owners): key = self._key(task_name, family, unbatched_args) for owner in self._owners(owners): self._disabled_counts[owner][key] += 1 self._fail_counts[owner].setdefault(key, 0) def add_scheduling_fail(self, task_name, family, unbatched_args, expl, owners): key = self._key(task_name, family, unbatched_args) for owner in self._owners(owners): self._scheduling_fail_counts[owner][key] += 1 self._fail_expls[owner][key].enqueue(expl) self._fail_counts[owner].setdefault(key, 0) def _task_expl_groups(self, expls): if not self._config.group_by_error_messages: return [((task,), msg) for task, msg in expls.items()] groups = collections.defaultdict(list) for task, msg in expls.items(): groups[msg].append(task) return [(tasks, msg) for msg, tasks in groups.items()] def _expls_key(self, expls_tuple): expls = expls_tuple[0] num_failures = sum(failures + scheduling_fails for (_1, failures, _2, scheduling_fails) in expls) num_disables = sum(disables for (_1, _2, disables, _3) in expls) min_name = min(expls)[0] return -num_failures, -num_disables, min_name def _expl_key(self, expl): return self._expls_key(((expl,), None)) def _email_body(self, fail_counts, disable_counts, scheduling_counts, fail_expls): expls = { (name, fail_count, disable_counts[name], scheduling_counts[name]): self._expl_body(fail_expls[name]) for name, fail_count in fail_counts.items() } expl_groups = sorted(self._task_expl_groups(expls), key=self._expls_key) body_lines = [] for tasks, msg in expl_groups: body_lines.append(self._format_tasks(tasks)) body_lines.append(msg) body = 
"\n".join(filter(None, body_lines)).rstrip() if self._email_format == "html": return "
      \n{}\n
    ".format(body) else: return body def _send_email(self, fail_counts, disable_counts, scheduling_counts, fail_expls, owner): num_failures = sum(fail_counts.values()) num_disables = sum(disable_counts.values()) num_scheduling_failures = sum(scheduling_counts.values()) subject_parts = [ _plural_format("{} failure{}", num_failures), _plural_format("{} disable{}", num_disables), _plural_format("{} scheduling failure{}", num_scheduling_failures), ] subject_base = ", ".join(filter(None, subject_parts)) if subject_base: prefix = "" if owner in self._default_owner else "Your tasks have " subject = "Luigi: {}{} in the last {} minutes".format(prefix, subject_base, self._config.email_interval) email_body = self._email_body(fail_counts, disable_counts, scheduling_counts, fail_expls) send_email(subject, email_body, email().sender, (owner,)) def send_email(self): try: for owner, failures in self._fail_counts.items(): self._send_email( fail_counts=failures, disable_counts=self._disabled_counts[owner], scheduling_counts=self._scheduling_fail_counts[owner], fail_expls=self._fail_expls[owner], owner=owner, ) finally: self._update_next_send() self._fail_counts.clear() self._disabled_counts.clear() self._scheduling_fail_counts.clear() self._fail_expls.clear() def update(self): if time.time() >= self._next_send: self.send_email() ================================================ FILE: luigi/cmdline.py ================================================ import argparse import sys from luigi.retcodes import run_with_retcodes from luigi.setup_logging import DaemonLogging def luigi_run(argv=sys.argv[1:]): run_with_retcodes(argv) def luigid(argv=sys.argv[1:]): import luigi.configuration import luigi.process import luigi.server parser = argparse.ArgumentParser(description="Central luigi server") parser.add_argument("--background", help="Run in background mode", action="store_true") parser.add_argument("--pidfile", help="Write pidfile") parser.add_argument("--logdir", help="log directory") 
parser.add_argument("--state-path", help="Pickled state file") parser.add_argument("--address", help="Listening interface") parser.add_argument("--unix-socket", help="Unix socket path") parser.add_argument("--port", default=8082, help="Listening port") opts = parser.parse_args(argv) if opts.state_path: config = luigi.configuration.get_config() config.set("scheduler", "state_path", opts.state_path) DaemonLogging.setup(opts) if opts.background: luigi.process.daemonize( luigi.server.run, api_port=opts.port, address=opts.address, pidfile=opts.pidfile, logdir=opts.logdir, unix_socket=opts.unix_socket ) else: luigi.server.run(api_port=opts.port, address=opts.address, unix_socket=opts.unix_socket) ================================================ FILE: luigi/cmdline_parser.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ This module contains luigi internal parsing logic. Things exposed here should be considered internal to luigi. """ import argparse import sys from contextlib import contextmanager from luigi.task_register import Register class CmdlineParser: """ Helper for parsing command line arguments and used as part of the context when instantiating task objects. Normal luigi users should just use :py:func:`luigi.run`. 
""" _instance = None @classmethod def get_instance(cls): """Singleton getter""" return cls._instance @classmethod @contextmanager def global_instance(cls, cmdline_args, allow_override=False): """ Meant to be used as a context manager. """ orig_value = cls._instance assert (orig_value is None) or allow_override new_value = None try: new_value = CmdlineParser(cmdline_args) cls._instance = new_value yield new_value finally: assert cls._instance is new_value cls._instance = orig_value def __init__(self, cmdline_args): """ Initialize cmd line args """ known_args, _ = self._build_parser().parse_known_args(args=cmdline_args) self._attempt_load_module(known_args) # We have to parse again now. As the positionally first unrecognized # argument (the task) could be different. known_args, _ = self._build_parser().parse_known_args(args=cmdline_args) root_task = known_args.root_task parser = self._build_parser(root_task=root_task, help_all=known_args.core_help_all) self._possibly_exit_with_help(parser, known_args) if not root_task: raise SystemExit("No task specified") else: # Check that what we believe to be the task is correctly spelled Register.get_task_cls(root_task) known_args = parser.parse_args(args=cmdline_args) self.known_args = known_args # Also publicly expose parsed arguments @staticmethod def _build_parser(root_task=None, help_all=False): parser = argparse.ArgumentParser(add_help=False) # Unfortunately, we have to set it as optional to argparse, so we can # parse out stuff like `--module` before we call for `--help`. parser.add_argument( "root_task", nargs="?", help="Task family to run. 
Is not optional.", metavar="Required root task", ) for task_name, is_without_section, param_name, param_obj in Register.get_all_params(): is_the_root_task = task_name == root_task help = param_obj.description if any((is_the_root_task, help_all, param_obj.always_in_help)) else argparse.SUPPRESS flag_name_underscores = param_name if is_without_section else task_name + "_" + param_name global_flag_name = "--" + flag_name_underscores.replace("_", "-") parser.add_argument(global_flag_name, help=help, **param_obj._parser_kwargs(param_name, task_name)) if is_the_root_task: local_flag_name = "--" + param_name.replace("_", "-") parser.add_argument(local_flag_name, help=help, **param_obj._parser_kwargs(param_name)) return parser def get_task_obj(self): """ Get the task object """ return self._get_task_cls()(**self._get_task_kwargs()) def _get_task_cls(self): """ Get the task class """ return Register.get_task_cls(self.known_args.root_task) def _get_task_kwargs(self): """ Get the local task arguments as a dictionary. The return value is in the form ``dict(my_param='my_value', ...)`` """ res = {} for param_name, param_obj in self._get_task_cls().get_params(): attr = getattr(self.known_args, param_name) if attr: res.update(((param_name, param_obj.parse(attr)),)) return res @staticmethod def _attempt_load_module(known_args): """ Load the --module parameter """ module = known_args.core_module if module: __import__(module) @staticmethod def _possibly_exit_with_help(parser, known_args): """ Check if the user passed --help[-all], if so, print a message and exit. """ if known_args.core_help or known_args.core_help_all: parser.print_help() sys.exit() ================================================ FILE: luigi/configuration/__init__.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from .cfg_parser import LuigiConfigParser from .core import add_config_path, get_config from .toml_parser import LuigiTomlParser __all__ = [ "add_config_path", "get_config", "LuigiConfigParser", "LuigiTomlParser", ] ================================================ FILE: luigi/configuration/base_parser.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import logging # IMPORTANT: don't inherit from `object`! # ConfigParser have some troubles in this case. 
# More info: https://stackoverflow.com/a/19323238 class BaseParser: @classmethod def instance(cls, *args, **kwargs): """Singleton getter""" if cls._instance is None: cls._instance = cls(*args, **kwargs) loaded = cls._instance.reload() logging.getLogger("luigi-interface").info("Loaded %r", loaded) return cls._instance @classmethod def add_config_path(cls, path): cls._config_paths.append(path) cls.reload() @classmethod def reload(cls): return cls.instance().read(cls._config_paths) ================================================ FILE: luigi/configuration/cfg_parser.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ luigi.configuration provides some convenience wrappers around Python's ConfigParser to get configuration options from config files. The default location for configuration files is luigi.cfg (or client.cfg) in the current working directory, then /etc/luigi/client.cfg. Configuration has largely been superseded by parameters since they can do essentially everything configuration can do, plus a tighter integration with the rest of Luigi. See :doc:`/configuration` for more info. 
""" import os import re import warnings from configparser import BasicInterpolation, ConfigParser, Interpolation, InterpolationError, NoOptionError, NoSectionError from .base_parser import BaseParser class InterpolationMissingEnvvarError(InterpolationError): """ Raised when option value refers to a nonexisting environment variable. """ def __init__(self, option, section, value, envvar): msg = ("Config refers to a nonexisting environment variable {}. Section [{}], option {}={}").format(envvar, section, option, value) InterpolationError.__init__(self, option, section, msg) class EnvironmentInterpolation(Interpolation): """ Custom interpolation which allows values to refer to environment variables using the ``${ENVVAR}`` syntax. """ _ENVRE = re.compile(r"\$\{([^}]+)\}") # matches "${envvar}" def before_get(self, parser, section, option, value, defaults): return self._interpolate_env(option, section, value) def _interpolate_env(self, option, section, value): rawval = value parts = [] while value: match = self._ENVRE.search(value) if match is None: parts.append(value) break envvar = match.groups()[0] try: envval = os.environ[envvar] except KeyError: raise InterpolationMissingEnvvarError(option, section, rawval, envvar) start, end = match.span() parts.append(value[:start]) parts.append(envval) value = value[end:] return "".join(parts) class CombinedInterpolation(Interpolation): """ Custom interpolation which applies multiple interpolations in series. :param interpolations: a sequence of configparser.Interpolation objects. 
""" def __init__(self, interpolations): self._interpolations = interpolations def before_get(self, parser, section, option, value, defaults): for interp in self._interpolations: value = interp.before_get(parser, section, option, value, defaults) return value def before_read(self, parser, section, option, value): for interp in self._interpolations: value = interp.before_read(parser, section, option, value) return value def before_set(self, parser, section, option, value): for interp in self._interpolations: value = interp.before_set(parser, section, option, value) return value def before_write(self, parser, section, option, value): for interp in self._interpolations: value = interp.before_write(parser, section, option, value) return value class LuigiConfigParser(BaseParser, ConfigParser): NO_DEFAULT = object() enabled = True optionxform = str # type: ignore _instance = None _config_paths = [ "/etc/luigi/client.cfg", # Deprecated old-style global luigi config "/etc/luigi/luigi.cfg", "client.cfg", # Deprecated old-style local luigi config "luigi.cfg", ] _DEFAULT_INTERPOLATION = CombinedInterpolation([BasicInterpolation(), EnvironmentInterpolation()]) @classmethod def reload(cls): # Warn about deprecated old-style config paths. deprecated_paths = [p for p in cls._config_paths if os.path.basename(p) == "client.cfg" and os.path.exists(p)] if deprecated_paths: warnings.warn( "Luigi configuration files named 'client.cfg' are deprecated if favor of 'luigi.cfg'. " + "Found: {paths!r}".format(paths=deprecated_paths), DeprecationWarning, ) return cls.instance().read(cls._config_paths) def _get_with_default(self, method, section, option, default, expected_type=None, **kwargs): """ Gets the value of the section/option using method. Returns default if value is not found. Raises an exception if the default value is not None and doesn't match the expected_type. 
""" try: try: # Underscore-style is the recommended configuration style option = option.replace("-", "_") return method(self, section, option, **kwargs) except (NoOptionError, NoSectionError): # Support dash-style option names (with deprecation warning). option_alias = option.replace("_", "-") value = method(self, section, option_alias, **kwargs) warn = "Configuration [{s}] {o} (with dashes) should be avoided. Please use underscores: {u}.".format(s=section, o=option_alias, u=option) warnings.warn(warn, DeprecationWarning) return value except (NoOptionError, NoSectionError): if default is LuigiConfigParser.NO_DEFAULT: raise if expected_type is not None and default is not None and not isinstance(default, expected_type): raise return default def has_option(self, section, option): """modified has_option Check for the existence of a given option in a given section. If the specified 'section' is None or an empty string, DEFAULT is assumed. If the specified 'section' does not exist, returns False. """ # Underscore-style is the recommended configuration style option = option.replace("-", "_") if ConfigParser.has_option(self, section, option): return True # Support dash-style option names (with deprecation warning). option_alias = option.replace("_", "-") if ConfigParser.has_option(self, section, option_alias): warn = "Configuration [{s}] {o} (with dashes) should be avoided. 
Please use underscores: {u}.".format(s=section, o=option_alias, u=option) warnings.warn(warn, DeprecationWarning) return True return False def get(self, section, option, default=NO_DEFAULT, **kwargs): return self._get_with_default(ConfigParser.get, section, option, default, **kwargs) def getboolean(self, section, option, default=NO_DEFAULT): return self._get_with_default(ConfigParser.getboolean, section, option, default, bool) def getint(self, section, option, default=NO_DEFAULT): return self._get_with_default(ConfigParser.getint, section, option, default, int) def getfloat(self, section, option, default=NO_DEFAULT): return self._get_with_default(ConfigParser.getfloat, section, option, default, float) def getintdict(self, section): try: # Exclude keys from [DEFAULT] section because in general they do not hold int values return dict((key, int(value)) for key, value in self.items(section) if key not in {k for k, _ in self.items("DEFAULT")}) except NoSectionError: return {} def set(self, section, option, value=None): if not ConfigParser.has_section(self, section): ConfigParser.add_section(self, section) return ConfigParser.set(self, section, option, value) ================================================ FILE: luigi/configuration/core.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import logging import os import warnings from .cfg_parser import LuigiConfigParser from .toml_parser import LuigiTomlParser logger = logging.getLogger("luigi-interface") PARSERS = { "cfg": LuigiConfigParser, "conf": LuigiConfigParser, "ini": LuigiConfigParser, "toml": LuigiTomlParser, } DEFAULT_PARSER = "cfg" def _get_default_parser(): parser = os.environ.get("LUIGI_CONFIG_PARSER", DEFAULT_PARSER) if parser not in PARSERS: warnings.warn("Invalid parser: {parser}".format(parser=DEFAULT_PARSER)) parser = DEFAULT_PARSER return parser def _check_parser(parser_class, parser): if not parser_class.enabled: msg = "Parser not installed yet. Please, install luigi with required parser:\npip install luigi[{parser}]" raise ImportError(msg.format(parser=parser)) def get_config(parser=None): """Get configs singleton for parser""" if parser is None: parser = _get_default_parser() parser_class = PARSERS[parser] _check_parser(parser_class, parser) return parser_class.instance() def add_config_path(path): """Select config parser by file extension and add path into parser.""" if not os.path.isfile(path): warnings.warn("Config file does not exist: {path}".format(path=path)) return False # select parser by file extension default_parser = _get_default_parser() _base, ext = os.path.splitext(path) if ext and ext[1:] in PARSERS: parser = ext[1:] else: parser = default_parser parser_class = PARSERS[parser] _check_parser(parser_class, parser) if parser != default_parser: msg = "Config for {added} parser added, but used {used} parser. 
Set up right parser via env var: export LUIGI_CONFIG_PARSER={added}" warnings.warn(msg.format(added=parser, used=default_parser)) # add config path to parser parser_class.add_config_path(path) return True if "LUIGI_CONFIG_PATH" in os.environ: add_config_path(os.environ["LUIGI_CONFIG_PATH"]) ================================================ FILE: luigi/configuration/toml_parser.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2018 Vote Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import os.path from configparser import ConfigParser from typing import Any, Dict try: import toml toml_enabled = True except ImportError: toml_enabled = False from ..freezing import recursively_freeze from .base_parser import BaseParser class LuigiTomlParser(BaseParser, ConfigParser): NO_DEFAULT = object() enabled = bool(toml_enabled) data: Dict[str, Any] = dict() _instance = None _config_paths = [ "/etc/luigi/luigi.toml", "luigi.toml", ] @staticmethod def _update_data(data, new_data): if not new_data: return data if not data: return new_data for section, content in new_data.items(): if section not in data: data[section] = dict() data[section].update(content) return data def read(self, config_paths): self.data = dict() for path in config_paths: if os.path.isfile(path): self.data = self._update_data(self.data, toml.load(path)) # freeze dict params for section, content in self.data.items(): for key, value in content.items(): if isinstance(value, dict): self.data[section][key] = recursively_freeze(value) return self.data def get(self, section, option, default=NO_DEFAULT, **kwargs): try: return self.data[section][option] except KeyError: if default is self.NO_DEFAULT: raise return default def getboolean(self, section, option, default=NO_DEFAULT): return self.get(section, option, default) def getint(self, section, option, default=NO_DEFAULT): return self.get(section, option, default) def getfloat(self, section, option, default=NO_DEFAULT): return self.get(section, option, default) def getintdict(self, section): return self.data.get(section, {}) def set(self, section, option, value=None): if section not in self.data: self.data[section] = {} self.data[section][option] = value def has_option(self, section, option): return section in self.data and option in self.data[section] def __getitem__(self, name): return self.data[name] ================================================ FILE: luigi/contrib/__init__.py ================================================ # -*- coding: 
# utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Package containing optional add-on functionality.
"""
# ================================================ FILE: luigi/contrib/azureblob.py ================================================
# -*- coding: utf-8 -*-
#
# Copyright (c) 2018 Microsoft Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#

import datetime
import logging
import os
import tempfile

from azure.storage.blob import BlobServiceClient

from luigi.format import get_default_format
from luigi.target import AtomicLocalFile, FileAlreadyExists, FileSystem, FileSystemTarget

logger = logging.getLogger("luigi-interface")


class AzureBlobClient(FileSystem):
    """
    Create an Azure Blob Storage client for authentication.
    Users can create multiple storage accounts, each of which acts like a silo. Under each storage account, we can
    create a container. Inside each container, the user can create multiple blobs.

    For each account, there should be an account key. This account key cannot be changed and one can access all the
    containers and blobs under this account using the account key.

    Usually using an account key might not always be the best idea as the key can be leaked and cannot be revoked. The
    solution to this issue is to create Shared Access Signatures (SAS). A SAS can be created for an entire
    container or just a single blob. SAS can be revoked.
    """

    def __init__(self, account_name=None, account_key=None, sas_token=None, **kwargs):
        """
        :param str account_name:
            The storage account name. This is used to authenticate requests signed with an account key
            and to construct the storage endpoint. It is required unless a connection string is given,
            or if a custom domain is used with anonymous authentication.
        :param str account_key:
            The storage account key. This is used for shared key authentication.
        :param str sas_token:
            A shared access signature token to use to authenticate requests instead of the account key.
        :param dict kwargs:
            A key-value pair to provide additional connection options.

            * `protocol` - The protocol to use for requests. Defaults to https.
            * `connection_string` - If specified, this will override all other parameters besides request session.
              See http://azure.microsoft.com/en-us/documentation/articles/storage-configure-connection-string/
              for the connection string format
            * `endpoint_suffix` - The host base component of the url, minus the account name. Defaults to Azure
              (core.windows.net). Override this to use the China cloud (core.chinacloudapi.cn).
            * `custom_domain` - The custom domain to use. This can be set in the Azure Portal.
              For example, 'www.mydomain.com'.
            * `token_credential` - A token credential used to authenticate HTTPS requests.
              The token value should be updated before its expiration.
        """
        if kwargs.get("custom_domain"):
            account_url = "{protocol}://{custom_domain}/{account_name}".format(
                protocol=kwargs.get("protocol", "https"), custom_domain=kwargs.get("custom_domain"), account_name=account_name
            )
        else:
            account_url = "{protocol}://{account_name}.blob.{endpoint_suffix}".format(
                protocol=kwargs.get("protocol", "https"), account_name=account_name, endpoint_suffix=kwargs.get("endpoint_suffix", "core.windows.net")
            )

        self.options = {"account_name": account_name, "account_key": account_key, "account_url": account_url, "sas_token": sas_token}
        self.kwargs = kwargs

    @property
    def connection(self):
        """A new ``BlobServiceClient`` built from the connection string if given, else from the account URL."""
        if self.kwargs.get("connection_string"):
            return BlobServiceClient.from_connection_string(conn_str=self.kwargs.get("connection_string"), **self.kwargs)
        else:
            return BlobServiceClient(
                account_url=self.options.get("account_url"),
                credential=self.options.get("account_key") or self.options.get("sas_token"),
                **self.kwargs,
            )

    def container_client(self, container_name):
        return self.connection.get_container_client(container_name)

    def blob_client(self, container_name, blob_name):
        container_client = self.container_client(container_name)
        return container_client.get_blob_client(blob_name)

    def upload(self, tmp_path, container, blob, **kwargs):
        """Upload the local file ``tmp_path`` to ``blob`` in ``container``, overwriting any existing blob.

        The container is created first if it does not exist. If the blob already exists, a lease is held on it
        for the duration of the upload. ``kwargs["progress_callback"]`` (if any) is passed as ``progress_hook``.
        """
        logger.debug("Uploading file '%s' to container '%s' and blob '%s'", tmp_path, container, blob)
        self.create_container(container)
        lease = None
        blob_client = self.blob_client(container, blob)
        if blob_client.exists():
            lease = blob_client.acquire_lease()
        try:
            with open(tmp_path, "rb") as data:
                blob_client.upload_blob(data, overwrite=True, lease=lease, progress_hook=kwargs.get("progress_callback"))
        finally:
            if lease is not None:
                lease.release()

    def download_as_bytes(self, container, blob, bytes_to_read=None):
        """Return the blob content as bytes; only the first ``bytes_to_read`` bytes when that is truthy."""
        logger.debug("Downloading from container '%s' and blob '%s' as bytes", container, blob)
        blob_client = self.blob_client(container, blob)
        download_stream = blob_client.download_blob(offset=0, length=bytes_to_read) if bytes_to_read else blob_client.download_blob()
        return download_stream.readall()

    def download_as_file(self, container, blob, location):
        """Download the blob to the local path ``location`` and return the blob's properties.

        Missing parent directories of ``location`` are created, so blob names containing ``/`` work.
        """
        logger.debug("Downloading from container '%s' and blob '%s' to %s", container, blob, location)
        parent_dir = os.path.dirname(location)
        if parent_dir:
            # Blob names may contain "/", which maps to subdirectories that do not exist yet.
            os.makedirs(parent_dir, exist_ok=True)
        blob_client = self.blob_client(container, blob)
        with open(location, "wb") as file:
            # Stream into the file instead of buffering the whole blob in memory.
            blob_client.download_blob().readinto(file)
        return blob_client.get_blob_properties()

    def create_container(self, container_name):
        if not self.exists(container_name):
            return self.connection.create_container(container_name)

    def delete_container(self, container_name):
        container_client = self.container_client(container_name)
        lease = container_client.acquire_lease()
        container_client.delete_container(lease=lease)

    def exists(self, path):
        """Check a ``container`` path (no ``/``) or a ``container/blob`` path."""
        container, blob = self.splitfilepath(path)
        if blob is None:
            return self.container_client(container).exists()
        else:
            return self.blob_client(container, blob).exists()

    def remove(self, path, recursive=True, skip_trash=True):
        """Delete the blob at ``path``; return False if it did not exist, True otherwise."""
        if not self.exists(path):
            return False

        container, blob = self.splitfilepath(path)
        blob_client = self.blob_client(container, blob)
        lease = blob_client.acquire_lease()
        blob_client.delete_blob(lease=lease)
        return True

    def mkdir(self, path, parents=True, raise_if_exists=False):
        """No-op apart from the optional existence check, as blob storage has no directories."""
        container, blob = self.splitfilepath(path)
        if raise_if_exists and self.exists(path):
            raise FileAlreadyExists("The Azure blob path '{blob}' already exists under container '{container}'".format(blob=blob, container=container))

    def isdir(self, path):
        """
        Azure Blob Storage has no concept of directories. It always returns False
        :param str path: Path of the Azure blob storage
        :return: False
        """
        return False

    def move(self, path, dest):
        """Copy ``path`` to ``dest`` then remove ``path``; on IOError remove ``dest`` and return False."""
        try:
            return self.copy(path, dest) and self.remove(path)
        except IOError:
            self.remove(dest)
            return False

    def copy(self, path, dest):
        """Server-side copy of a blob within a single container, holding leases on source (and existing dest)."""
        source_container, source_blob = self.splitfilepath(path)
        dest_container, dest_blob = self.splitfilepath(dest)
        if source_container != dest_container:
            raise Exception(
                "Can't copy blob from '{source_container}' to '{dest_container}'. File can be moved within container".format(
                    source_container=source_container, dest_container=dest_container
                )
            )

        source_blob_client = self.blob_client(source_container, source_blob)
        dest_blob_client = self.blob_client(dest_container, dest_blob)
        source_lease = source_blob_client.acquire_lease()
        destination_lease = dest_blob_client.acquire_lease() if self.exists(dest) else None
        try:
            return dest_blob_client.start_copy_from_url(source_url=source_blob_client.url, source_lease=source_lease, destination_lease=destination_lease)
        finally:
            source_lease.release()
            if destination_lease is not None:
                destination_lease.release()

    def rename_dont_move(self, path, dest):
        self.move(path, dest)

    @staticmethod
    def splitfilepath(filepath):
        """Split ``"container/a/b"`` into ``("container", "a/b")``; blob is None when there is no ``/``."""
        splitpath = filepath.split("/")
        container = splitpath[0]
        blobsplit = splitpath[1:]
        blob = None if not blobsplit else "/".join(blobsplit)
        return container, blob


class ReadableAzureBlobFile:
    """Read-only file-like view of a blob, optionally backed by a downloaded temporary copy."""

    def __init__(self, container, blob, client, download_when_reading, **kwargs):
        self.container = container
        self.blob = blob
        self.client = client
        self.closed = False
        self.download_when_reading = download_when_reading
        self.azure_blob_options = kwargs
        self.download_file_location = os.path.join(tempfile.mkdtemp(prefix=str(datetime.datetime.utcnow())), blob)
        self.fid = None

    def read(self, n=None):
        return self.client.download_as_bytes(self.container, self.blob, n)

    def __enter__(self):
        if self.download_when_reading:
            self.client.download_as_file(self.container, self.blob, self.download_file_location)
            self.fid = open(self.download_file_location)
            return self.fid
        else:
            return self

    def __exit__(self, exc_type, exc, traceback):
        self.close()

    def __del__(self):
        self.close()
        # Remove the temporary download. This must be a filesystem check (os.path.exists);
        # os._exists only looks up names in the os module and never matches a path.
        if os.path.exists(self.download_file_location):
            os.remove(self.download_file_location)

    def close(self):
        if self.download_when_reading:
            if self.fid is not None and not self.fid.closed:
                self.fid.close()
                self.fid = None

    def readable(self):
        return True

    def writable(self):
        return False

    def seekable(self):
        return False

    def seek(self, offset, whence=None):
        pass


class AtomicAzureBlobFile(AtomicLocalFile):
    """Local temporary file that is uploaded to the blob when the atomic write completes."""

    def __init__(self, container, blob, client, **kwargs):
        super(AtomicAzureBlobFile, self).__init__(os.path.join(container, blob))
        self.container = container
        self.blob = blob
        self.client = client
        self.azure_blob_options = kwargs

    def move_to_final_destination(self):
        self.client.upload(self.tmp_path, self.container, self.blob, **self.azure_blob_options)


class AzureBlobTarget(FileSystemTarget):
    """
    Create an Azure Blob Target for storing data on Azure Blob Storage
    """

    def __init__(self, container, blob, client=None, format=None, download_when_reading=True, **kwargs):
        """
        :param str container:
            The azure container in which the blob needs to be stored
        :param str blob:
            The name of the blob under container specified
        :param AzureBlobClient client:
            An instance of :class:`.AzureBlobClient`. If none is specified, a default
            ``AzureBlobClient()`` with no credentials is created
        :param format:
            An instance of :class:`luigi.format`. Defaults to ``get_default_format()``.
        :param bool download_when_reading:
            Determines whether the file has to be downloaded to temporary location on disk. Defaults to `True`.

        Pass the argument **progress_callback** with signature *(func(current, total))* to get real time progress of upload
        """
        super(AzureBlobTarget, self).__init__(os.path.join(container, blob))
        if format is None:
            format = get_default_format()
        self.container = container
        self.blob = blob
        self.client = client or AzureBlobClient()
        self.format = format
        self.download_when_reading = download_when_reading
        self.azure_blob_options = kwargs

    @property
    def fs(self):
        """
        The :py:class:`FileSystem` associated with :class:`.AzureBlobTarget`
        """
        return self.client

    def open(self, mode):
        """
        Open the target for reading or writing

        :param char mode:
            'r' for reading and 'w' for writing. No other mode (including 'rb'/'wb') is accepted;
            for binary data, pass an appropriate `format` instead.
        :return:
            * :class:`.ReadableAzureBlobFile` if 'r'
            * :class:`.AtomicAzureBlobFile` if 'w'
        :raises ValueError: for any mode other than 'r' or 'w'
        """
        if mode not in ("r", "w"):
            raise ValueError("Unsupported open mode '%s'" % mode)
        if mode == "r":
            return self.format.pipe_reader(ReadableAzureBlobFile(self.container, self.blob, self.client, self.download_when_reading, **self.azure_blob_options))
        else:
            return self.format.pipe_writer(AtomicAzureBlobFile(self.container, self.blob, self.client, **self.azure_blob_options))


# ================================================ FILE: luigi/contrib/batch.py ================================================
# -*- coding: utf-8 -*-
#
# Copyright 2018 Outlier Bio, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# """ AWS Batch wrapper for Luigi From the AWS website: AWS Batch enables you to run batch computing workloads on the AWS Cloud. Batch computing is a common way for developers, scientists, and engineers to access large amounts of compute resources, and AWS Batch removes the undifferentiated heavy lifting of configuring and managing the required infrastructure. AWS Batch is similar to traditional batch computing software. This service can efficiently provision resources in response to jobs submitted in order to eliminate capacity constraints, reduce compute costs, and deliver results quickly. See `AWS Batch User Guide`_ for more details. To use AWS Batch, you create a jobDefinition JSON that defines a `docker run`_ command, and then submit this JSON to the API to queue up the task. Behind the scenes, AWS Batch auto-scales a fleet of EC2 Container Service instances, monitors the load on these instances, and schedules the jobs. This `boto3-powered`_ wrapper allows you to create Luigi Tasks to submit Batch ``jobDefinition``s. You can either pass a dict (mapping directly to the ``jobDefinition`` JSON) OR an Amazon Resource Name (arn) for a previously registered ``jobDefinition``. Requires: - boto3 package - Amazon AWS credentials discoverable by boto3 (e.g., by using ``aws configure`` from awscli_) - An enabled AWS Batch job queue configured to run on a compute environment. Written and maintained by Jake Feala (@jfeala) for Outlier Bio (@outlierbio) .. _`docker run`: https://docs.docker.com/reference/commandline/run .. _jobDefinition: http://http://docs.aws.amazon.com/batch/latest/userguide/job_definitions.html .. _`boto3-powered`: https://boto3.readthedocs.io .. _awscli: https://aws.amazon.com/cli .. 
_`AWS Batch User Guide`: http://docs.aws.amazon.com/AmazonECS/latest/developerguide/ECS_GetStarted.html """ import json import logging import random import string import time import luigi logger = logging.getLogger(__name__) try: import boto3 except ImportError: logger.warning("boto3 is not installed. BatchTasks require boto3") class BatchJobException(Exception): pass POLL_TIME = 10 def _random_id(): return "batch-job-" + "".join(random.sample(string.ascii_lowercase, 8)) class BatchClient: def __init__(self, poll_time=POLL_TIME): self.poll_time = poll_time self._client = boto3.client("batch") self._log_client = boto3.client("logs") self._queue = self.get_active_queue() def get_active_queue(self): """Get name of first active job queue""" # Get dict of active queues keyed by name queues = {q["jobQueueName"]: q for q in self._client.describe_job_queues()["jobQueues"] if q["state"] == "ENABLED" and q["status"] == "VALID"} if not queues: raise Exception("No job queues with state=ENABLED and status=VALID") # Pick the first queue as default return list(queues.keys())[0] def get_job_id_from_name(self, job_name): """Retrieve the first job ID matching the given name""" jobs = self._client.list_jobs(jobQueue=self._queue, jobStatus="RUNNING")["jobSummaryList"] matching_jobs = [job for job in jobs if job["jobName"] == job_name] if matching_jobs: return matching_jobs[0]["jobId"] def get_job_status(self, job_id): """Retrieve task statuses from ECS API :param job_id (str): AWS Batch job uuid Returns one of {SUBMITTED|PENDING|RUNNABLE|STARTING|RUNNING|SUCCEEDED|FAILED} """ response = self._client.describe_jobs(jobs=[job_id]) # Error checking status_code = response["ResponseMetadata"]["HTTPStatusCode"] if status_code != 200: msg = "Job status request received status code {0}:\n{1}" raise Exception(msg.format(status_code, response)) return response["jobs"][0]["status"] def get_logs(self, log_stream_name, get_last=50): """Retrieve log stream from CloudWatch""" response = 
self._log_client.get_log_events(logGroupName="/aws/batch/job", logStreamName=log_stream_name, startFromHead=False) events = response["events"] return "\n".join(e["message"] for e in events[-get_last:]) def submit_job(self, job_definition, parameters, job_name=None, queue=None): """Wrap submit_job with useful defaults""" if job_name is None: job_name = _random_id() response = self._client.submit_job(jobName=job_name, jobQueue=queue or self.get_active_queue(), jobDefinition=job_definition, parameters=parameters) return response["jobId"] def wait_on_job(self, job_id): """Poll task status until STOPPED""" while True: status = self.get_job_status(job_id) if status == "SUCCEEDED": logger.info("Batch job {} SUCCEEDED".format(job_id)) return True elif status == "FAILED": # Raise and notify if job failed jobs = self._client.describe_jobs(jobs=[job_id])["jobs"] job_str = json.dumps(jobs, indent=4) logger.debug("Job details:\n" + job_str) log_stream_name = jobs[0]["attempts"][0]["container"]["logStreamName"] logs = self.get_logs(log_stream_name) raise BatchJobException("Job {} failed: {}".format(job_id, logs)) time.sleep(self.poll_time) logger.debug("Batch job status for job {0}: {1}".format(job_id, status)) def register_job_definition(self, json_fpath): """Register a job definition with AWS Batch, using a JSON""" with open(json_fpath) as f: job_def = json.load(f) response = self._client.register_job_definition(**job_def) status_code = response["ResponseMetadata"]["HTTPStatusCode"] if status_code != 200: msg = "Register job definition request received status code {0}:\n{1}" raise Exception(msg.format(status_code, response)) return response class BatchTask(luigi.Task): """ Base class for an Amazon Batch job Amazon Batch requires you to register "job definitions", which are JSON descriptions for how to issue the ``docker run`` command. 
This Luigi Task requires a pre-registered Batch jobDefinition name passed as a Parameter :param job_definition (str): name of pre-registered jobDefinition :param job_name: name of specific job, for tracking in the queue and logs. :param job_queue: name of job queue where job is going to be submitted. """ job_definition = luigi.Parameter() job_name = luigi.OptionalParameter(default=None) job_queue = luigi.OptionalParameter(default=None) poll_time = luigi.IntParameter(default=POLL_TIME) def run(self): bc = BatchClient(self.poll_time) job_id = bc.submit_job(self.job_definition, self.parameters, job_name=self.job_name, queue=self.job_queue) bc.wait_on_job(job_id) @property def parameters(self): """Override to return a dict of parameters for the Batch Task""" return {} ================================================ FILE: luigi/contrib/beam_dataflow.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2019 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import abc import json import logging import os import subprocess import luigi from luigi.contrib import bigquery, gcs from luigi.task import MixinNaiveBulkComplete logger = logging.getLogger("luigi-interface") class DataflowParamKeys(metaclass=abc.ABCMeta): """ Defines the naming conventions for Dataflow execution params. For example, the Java API expects param names in lower camel case, whereas the Python implementation expects snake case. 
""" @property @abc.abstractmethod def runner(self): pass @property @abc.abstractmethod def project(self): pass @property @abc.abstractmethod def zone(self): pass @property @abc.abstractmethod def region(self): pass @property @abc.abstractmethod def staging_location(self): pass @property @abc.abstractmethod def temp_location(self): pass @property @abc.abstractmethod def gcp_temp_location(self): pass @property @abc.abstractmethod def num_workers(self): pass @property @abc.abstractmethod def autoscaling_algorithm(self): pass @property @abc.abstractmethod def max_num_workers(self): pass @property @abc.abstractmethod def disk_size_gb(self): pass @property @abc.abstractmethod def worker_machine_type(self): pass @property @abc.abstractmethod def worker_disk_type(self): pass @property @abc.abstractmethod def job_name(self): pass @property @abc.abstractmethod def service_account(self): pass @property @abc.abstractmethod def network(self): pass @property @abc.abstractmethod def subnetwork(self): pass @property @abc.abstractmethod def labels(self): pass class _CmdLineRunner: """ Executes a given command line class in a subprocess, logging its output. If more complex monitoring/logging is desired, user can implement their own launcher class and set it in BeamDataflowJobTask.cmd_line_runner. """ @staticmethod def run(cmd, task=None): process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True) output_lines = [] while True: line = process.stdout.readline() if not line: break line = line.decode("utf-8") output_lines += [line] logger.info(line.rstrip("\n")) process.stdout.close() exit_code = process.wait() if exit_code: output = "".join(output_lines) raise subprocess.CalledProcessError(exit_code, cmd, output=output) class BeamDataflowJobTask(MixinNaiveBulkComplete, luigi.Task, metaclass=abc.ABCMeta): """ Luigi wrapper for a Dataflow job. Must be overridden for each Beam SDK with that SDK's dataflow_executable(). 
For more documentation, see: https://cloud.google.com/dataflow/docs/guides/specifying-exec-params The following required Dataflow properties must be set: project # GCP project ID temp_location # Cloud storage path for temporary files The following optional Dataflow properties can be set: runner # PipelineRunner implementation for your Beam job. Default: DirectRunner num_workers # The number of workers to start the task with Default: Determined by Dataflow service autoscaling_algorithm # The Autoscaling mode for the Dataflow job Default: `THROUGHPUT_BASED` max_num_workers # Used if the autoscaling is enabled Default: Determined by Dataflow service network # Network in GCE to be used for launching workers Default: a network named "default" subnetwork # Subnetwork in GCE to be used for launching workers Default: Determined by Dataflow service disk_size_gb # Remote worker disk size. Minimum value is 30GB Default: set to 0 to use GCP project default worker_machine_type # Machine type to create Dataflow worker VMs Default: Determined by Dataflow service job_name # Custom job name, must be unique across project's active jobs worker_disk_type # Specify SSD for local disk or defaults to hard disk as a full URL of disk type resource Default: Determined by Dataflow service. 
service_account # Service account of Dataflow VMs/workers Default: active GCE service account region # Region to deploy Dataflow job to Default: us-central1 zone # Availability zone for launching workers instances Default: an available zone in the specified region staging_location # Cloud Storage bucket for Dataflow to stage binary files Default: the value of temp_location gcp_temp_location # Cloud Storage path for Dataflow to stage temporary files Default: the value of temp_location labels # Custom GCP labels attached to the Dataflow job Default: nothing """ project = None runner = None temp_location = None staging_location = None gcp_temp_location = None num_workers = None autoscaling_algorithm = None max_num_workers = None network = None subnetwork = None disk_size_gb = None worker_machine_type = None job_name = None worker_disk_type = None service_account = None zone = None region = None labels: dict[str, str] = {} cmd_line_runner = _CmdLineRunner dataflow_params = None def __init__(self): if not isinstance(self.dataflow_params, DataflowParamKeys): raise ValueError("dataflow_params must be of type DataflowParamKeys") super(BeamDataflowJobTask, self).__init__() @abc.abstractmethod def dataflow_executable(self): """ Command representing the Dataflow executable to be run. For example: return ['java', 'com.spotify.luigi.MyClass', '-Xmx256m'] """ pass def args(self): """ Extra String arguments that will be passed to your Dataflow job. For example: return ['--setup_file=setup.py'] """ return [] def before_run(self): """ Hook that gets called right before the Dataflow job is launched. Can be used to setup any temporary files/tables, validate input, etc. """ pass def on_successful_run(self): """ Callback that gets called right after the Dataflow job has finished successfully but before validate_output is run. """ pass def validate_output(self): """ Callback that can be used to validate your output before it is moved to its final location. 
Returning false here will cause the job to fail, and output to be removed instead of published. """ return True def file_pattern(self): """ If one/some of the input target files are not in the pattern of part-*, we can add the key of the required target and the correct file pattern that should be appended in the command line here. If the input target key is not found in this dict, the file pattern will be assumed to be part-* for that target. :return A dictionary of overridden file pattern that is not part-* for the inputs """ return {} def on_successful_output_validation(self): """ Callback that gets called after the Dataflow job has finished successfully if validate_output returns True. """ pass def cleanup_on_error(self, error): """ Callback that gets called after the Dataflow job has finished unsuccessfully, or validate_output returns False. """ pass def run(self): cmd_line = self._mk_cmd_line() logger.info(" ".join(cmd_line)) self.before_run() try: self.cmd_line_runner.run(cmd_line, self) except subprocess.CalledProcessError as e: logger.error(e, exc_info=True) self.cleanup_on_error(e) os._exit(e.returncode) self.on_successful_run() if self.validate_output(): self.on_successful_output_validation() else: error = ValueError("Output validation failed") self.cleanup_on_error(error) raise error def _mk_cmd_line(self): cmd_line = self.dataflow_executable() cmd_line.extend(self._get_dataflow_args()) cmd_line.extend(self.args()) cmd_line.extend(self._format_input_args()) cmd_line.extend(self._format_output_args()) return cmd_line def _get_runner(self): if not self.runner: logger.warning("Runner not supplied to BeamDataflowJobTask. " + "Defaulting to DirectRunner.") return "DirectRunner" elif self.runner in ["DataflowRunner", "DirectRunner"]: return self.runner else: raise ValueError("Runner %s is unsupported." 
% self.runner) def _get_dataflow_args(self): def f(key, value): return "--{}={}".format(key, value) output = [] output.append(f(self.dataflow_params.runner, self._get_runner())) if self.project: output.append(f(self.dataflow_params.project, self.project)) if self.zone: output.append(f(self.dataflow_params.zone, self.zone)) if self.region: output.append(f(self.dataflow_params.region, self.region)) if self.staging_location: output.append(f(self.dataflow_params.staging_location, self.staging_location)) if self.temp_location: output.append(f(self.dataflow_params.temp_location, self.temp_location)) if self.gcp_temp_location: output.append(f(self.dataflow_params.gcp_temp_location, self.gcp_temp_location)) if self.num_workers: output.append(f(self.dataflow_params.num_workers, self.num_workers)) if self.autoscaling_algorithm: output.append(f(self.dataflow_params.autoscaling_algorithm, self.autoscaling_algorithm)) if self.max_num_workers: output.append(f(self.dataflow_params.max_num_workers, self.max_num_workers)) if self.disk_size_gb: output.append(f(self.dataflow_params.disk_size_gb, self.disk_size_gb)) if self.worker_machine_type: output.append(f(self.dataflow_params.worker_machine_type, self.worker_machine_type)) if self.worker_disk_type: output.append(f(self.dataflow_params.worker_disk_type, self.worker_disk_type)) if self.network: output.append(f(self.dataflow_params.network, self.network)) if self.subnetwork: output.append(f(self.dataflow_params.subnetwork, self.subnetwork)) if self.job_name: output.append(f(self.dataflow_params.job_name, self.job_name)) if self.service_account: output.append(f(self.dataflow_params.service_account, self.service_account)) if self.labels: output.append(f(self.dataflow_params.labels, json.dumps(self.labels))) return output def _format_input_args(self): """ Parses the result(s) of self.input() into a string-serialized key-value list passed to the Dataflow job. 
Valid inputs include: return FooTarget() return {"input1": FooTarget(), "input2": FooTarget2()) return ("input", FooTarget()) return [("input1", FooTarget()), ("input2": FooTarget2())] return [FooTarget(), FooTarget2()] Unlabeled input are passed in with under the default key "input". """ job_input = self.input() if isinstance(job_input, luigi.Target): job_input = {"input": job_input} elif isinstance(job_input, tuple): job_input = {job_input[0]: job_input[1]} elif isinstance(job_input, list): if all(isinstance(item, tuple) for item in job_input): job_input = dict(job_input) else: job_input = {"input": job_input} elif not isinstance(job_input, dict): raise ValueError("Invalid job input requires(). Supported types: [Target, tuple of (name, Target), dict of (name: Target), list of Targets]") if not isinstance(self.file_pattern(), dict): raise ValueError("file_pattern() must return a dict type") input_args = [] for name, targets in job_input.items(): uris = [self.get_target_path(uri_target) for uri_target in luigi.task.flatten(targets)] if isinstance(targets, dict): """ If targets is a dict that means it had multiple outputs. Make the input args in that case "-" """ names = ["%s-%s" % (name, key) for key in targets.keys()] else: names = [name] * len(uris) input_dict = {} for arg_name, uri in zip(names, uris): pattern = self.file_pattern().get(name, "part-*") input_value = input_dict.get(arg_name, []) input_value.append(uri.rstrip("/") + "/" + pattern) input_dict[arg_name] = input_value for key, paths in input_dict.items(): input_args.append("--%s=%s" % (key, ",".join(paths))) return input_args def _format_output_args(self): """ Parses the result(s) of self.output() into a string-serialized key-value list passed to the Dataflow job. Valid outputs include: return FooTarget() return {"output1": FooTarget(), "output2": FooTarget2()} Unlabeled outputs are passed in with under the default key "output". 
""" job_output = self.output() if isinstance(job_output, luigi.Target): job_output = {"output": job_output} elif not isinstance(job_output, dict): raise ValueError("Task output must be a Target or a dict from String to Target") output_args = [] for name, target in job_output.items(): uri = self.get_target_path(target) output_args.append("--%s=%s" % (name, uri)) return output_args @staticmethod def get_target_path(target): """ Given a luigi Target, determine a stringly typed path to pass as a Dataflow job argument. """ if isinstance(target, luigi.LocalTarget) or isinstance(target, gcs.GCSTarget): return target.path elif isinstance(target, bigquery.BigQueryTarget): return "{}:{}.{}".format(target.table.project_id, target.table.dataset_id, target.table.table_id) else: raise ValueError("Target %s not supported" % target) ================================================ FILE: luigi/contrib/bigquery.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015 Twitter Inc # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations import collections import logging import time from tenacity import retry, retry_if_exception, retry_if_exception_type, stop_after_attempt, wait_exponential import luigi.target from luigi.contrib import gcp logger = logging.getLogger("luigi-interface") RETRYABLE_ERRORS: tuple[type[BaseException], ...] 
= () try: import httplib2 from googleapiclient import discovery, errors, http except ImportError: logger.warning("BigQuery module imported, but google-api-python-client is not installed. Any BigQuery task will fail") else: RETRYABLE_ERRORS = (httplib2.HttpLib2Error, IOError, TimeoutError, BrokenPipeError) # Retry configurations. For more details, see https://tenacity.readthedocs.io/en/latest/ def is_error_5xx(err): return isinstance(err, errors.HttpError) and err.resp.status >= 500 bq_retry = retry( retry=(retry_if_exception(is_error_5xx) | retry_if_exception_type(RETRYABLE_ERRORS)), wait=wait_exponential(multiplier=1, min=1, max=10), stop=stop_after_attempt(3), reraise=True, after=lambda x: x.args[0]._initialise_client(), ) class CreateDisposition: CREATE_IF_NEEDED = "CREATE_IF_NEEDED" CREATE_NEVER = "CREATE_NEVER" class WriteDisposition: WRITE_TRUNCATE = "WRITE_TRUNCATE" WRITE_APPEND = "WRITE_APPEND" WRITE_EMPTY = "WRITE_EMPTY" class QueryMode: INTERACTIVE = "INTERACTIVE" BATCH = "BATCH" class SourceFormat: AVRO = "AVRO" CSV = "CSV" DATASTORE_BACKUP = "DATASTORE_BACKUP" NEWLINE_DELIMITED_JSON = "NEWLINE_DELIMITED_JSON" PARQUET = "PARQUET" class FieldDelimiter: """ The separator for fields in a CSV file. The separator can be any ISO-8859-1 single-byte character. To use a character in the range 128-255, you must encode the character as UTF8. BigQuery converts the string to ISO-8859-1 encoding, and then uses the first byte of the encoded string to split the data in its raw, binary state. BigQuery also supports the escape sequence "\t" to specify a tab separator. The default value is a comma (','). 
https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.load """ COMMA = "," # Default TAB = "\t" PIPE = "|" class PrintHeader: TRUE = True FALSE = False class DestinationFormat: AVRO = "AVRO" CSV = "CSV" NEWLINE_DELIMITED_JSON = "NEWLINE_DELIMITED_JSON" class Compression: GZIP = "GZIP" NONE = "NONE" class Encoding: """ [Optional] The character encoding of the data. The supported values are UTF-8 or ISO-8859-1. The default value is UTF-8. BigQuery decodes the data after the raw, binary data has been split using the values of the quote and fieldDelimiter properties. """ UTF_8 = "UTF-8" ISO_8859_1 = "ISO-8859-1" BQDataset = collections.namedtuple("BQDataset", "project_id dataset_id location") class BQTable(collections.namedtuple("BQTable", "project_id dataset_id table_id location")): @property def dataset(self): return BQDataset(project_id=self.project_id, dataset_id=self.dataset_id, location=self.location) @property def uri(self): return "bq://" + self.project_id + "/" + self.dataset.dataset_id + "/" + self.table_id class BigQueryClient: """A client for Google BigQuery. For details of how authentication and the descriptor work, see the documentation for the GCS client. The descriptor URL for BigQuery is https://www.googleapis.com/discovery/v1/apis/bigquery/v2/rest """ def __init__(self, oauth_credentials=None, descriptor="", http_=None): # Save initialisation arguments in case we need to re-create client # due to connection timeout self.oauth_credentials = oauth_credentials self.descriptor = descriptor self.http_ = http_ self._initialise_client() def _initialise_client(self): authenticate_kwargs = gcp.get_authenticate_kwargs(self.oauth_credentials, self.http_) if self.descriptor: self.client = discovery.build_from_document(self.descriptor, **authenticate_kwargs) else: self.client = discovery.build("bigquery", "v2", cache_discovery=False, **authenticate_kwargs) @bq_retry def dataset_exists(self, dataset): """Returns whether the given dataset exists. 
If regional location is specified for the dataset, that is also checked to be compatible with the remote dataset, otherwise an exception is thrown. :param dataset: :type dataset: BQDataset """ try: response = self.client.datasets().get(projectId=dataset.project_id, datasetId=dataset.dataset_id).execute() if dataset.location is not None: fetched_location = response.get("location") if dataset.location != fetched_location: raise Exception( """Dataset already exists with regional location {}. Can't use {}.""".format( fetched_location if fetched_location is not None else "unspecified", dataset.location ) ) except http.HttpError as ex: if ex.resp.status == 404: return False raise return True @bq_retry def table_exists(self, table): """Returns whether the given table exists. :param table: :type table: BQTable """ if not self.dataset_exists(table.dataset): return False try: self.client.tables().get(projectId=table.project_id, datasetId=table.dataset_id, tableId=table.table_id).execute() except http.HttpError as ex: if ex.resp.status == 404: return False raise return True def make_dataset(self, dataset, raise_if_exists=False, body=None): """Creates a new dataset with the default permissions. :param dataset: :type dataset: BQDataset :param raise_if_exists: whether to raise an exception if the dataset already exists. 
:raises luigi.target.FileAlreadyExists: if raise_if_exists=True and the dataset exists """ if body is None: body = {} try: # Construct a message body in the format required by # https://developers.google.com/resources/api-libraries/documentation/bigquery/v2/python/latest/bigquery_v2.datasets.html#insert body["datasetReference"] = {"projectId": dataset.project_id, "datasetId": dataset.dataset_id} if dataset.location is not None: body["location"] = dataset.location self.client.datasets().insert(projectId=dataset.project_id, body=body).execute() except http.HttpError as ex: if ex.resp.status == 409: if raise_if_exists: raise luigi.target.FileAlreadyExists() else: raise def delete_dataset(self, dataset, delete_nonempty=True): """Deletes a dataset (and optionally any tables in it), if it exists. :param dataset: :type dataset: BQDataset :param delete_nonempty: if true, will delete any tables before deleting the dataset """ if not self.dataset_exists(dataset): return self.client.datasets().delete(projectId=dataset.project_id, datasetId=dataset.dataset_id, deleteContents=delete_nonempty).execute() def delete_table(self, table): """Deletes a table, if it exists. :param table: :type table: BQTable """ if not self.table_exists(table): return self.client.tables().delete(projectId=table.project_id, datasetId=table.dataset_id, tableId=table.table_id).execute() def list_datasets(self, project_id): """Returns the list of datasets in a given project. :param project_id: :type project_id: str """ request = self.client.datasets().list(projectId=project_id, maxResults=1000) response = request.execute() while response is not None: for ds in response.get("datasets", []): yield ds["datasetReference"]["datasetId"] request = self.client.datasets().list_next(request, response) if request is None: break response = request.execute() def list_tables(self, dataset): """Returns the list of tables in a given dataset. 
:param dataset: :type dataset: BQDataset """ request = self.client.tables().list(projectId=dataset.project_id, datasetId=dataset.dataset_id, maxResults=1000) response = request.execute() while response is not None: for t in response.get("tables", []): yield t["tableReference"]["tableId"] request = self.client.tables().list_next(request, response) if request is None: break response = request.execute() def get_view(self, table): """Returns the SQL query for a view, or None if it doesn't exist or is not a view. :param table: The table containing the view. :type table: BQTable """ request = self.client.tables().get(projectId=table.project_id, datasetId=table.dataset_id, tableId=table.table_id) try: response = request.execute() except http.HttpError as ex: if ex.resp.status == 404: return None raise return response["view"]["query"] if "view" in response else None def update_view(self, table, view): """Updates the SQL query for a view. If the output table exists, it is replaced with the supplied view query. Otherwise a new table is created with this view. :param table: The table to contain the view. :type table: BQTable :param view: The SQL query for the view. :type view: str """ body = {"tableReference": {"projectId": table.project_id, "datasetId": table.dataset_id, "tableId": table.table_id}, "view": {"query": view}} if self.table_exists(table): self.client.tables().update(projectId=table.project_id, datasetId=table.dataset_id, tableId=table.table_id, body=body).execute() else: self.client.tables().insert(projectId=table.project_id, datasetId=table.dataset_id, body=body).execute() def run_job(self, project_id, body, dataset=None): """Runs a BigQuery "job". See the documentation for the format of body. .. note:: You probably don't need to use this directly. Use the tasks defined below. :param dataset: :type dataset: BQDataset :return: the job id of the job. :rtype: str :raises luigi.contrib.BigQueryExecutionError: if the job fails. 
""" if dataset and not self.dataset_exists(dataset): self.make_dataset(dataset) new_job = self.client.jobs().insert(projectId=project_id, body=body).execute() job_id = new_job["jobReference"]["jobId"] logger.info("Started import job %s:%s", project_id, job_id) while True: status = self.client.jobs().get(projectId=project_id, jobId=job_id).execute(num_retries=10) if status["status"]["state"] == "DONE": if status["status"].get("errorResult"): raise BigQueryExecutionError(job_id, status["status"]["errorResult"]) return job_id logger.info("Waiting for job %s:%s to complete...", project_id, job_id) time.sleep(5) def copy(self, source_table, dest_table, create_disposition=CreateDisposition.CREATE_IF_NEEDED, write_disposition=WriteDisposition.WRITE_TRUNCATE): """Copies (or appends) a table to another table. :param source_table: :type source_table: BQTable :param dest_table: :type dest_table: BQTable :param create_disposition: whether to create the table if needed :type create_disposition: CreateDisposition :param write_disposition: whether to append/truncate/fail if the table exists :type write_disposition: WriteDisposition """ job = { "configuration": { "copy": { "sourceTable": { "projectId": source_table.project_id, "datasetId": source_table.dataset_id, "tableId": source_table.table_id, }, "destinationTable": { "projectId": dest_table.project_id, "datasetId": dest_table.dataset_id, "tableId": dest_table.table_id, }, "createDisposition": create_disposition, "writeDisposition": write_disposition, } } } self.run_job(dest_table.project_id, job, dataset=dest_table.dataset) class BigQueryTarget(luigi.target.Target): def __init__(self, project_id, dataset_id, table_id, client=None, location=None): self.table = BQTable(project_id=project_id, dataset_id=dataset_id, table_id=table_id, location=location) self.client = client or BigQueryClient() @classmethod def from_bqtable(cls, table, client=None): """A constructor that takes a :py:class:`BQTable`. 
:param table: :type table: BQTable """ return cls(table.project_id, table.dataset_id, table.table_id, client=client) def exists(self): return self.client.table_exists(self.table) def __str__(self): return str(self.table) class MixinBigQueryBulkComplete: """ Allows to efficiently check if a range of BigQueryTargets are complete. This enables scheduling tasks with luigi range tools. If you implement a custom Luigi task with a BigQueryTarget output, make sure to also inherit from this mixin to enable range support. """ @classmethod def bulk_complete(cls, parameter_tuples): # Instantiate the tasks to inspect them tasks_with_params = [(cls(p), p) for p in parameter_tuples] if not tasks_with_params: return # Grab the set of BigQuery datasets we are interested in datasets = {t.output().table.dataset for t, p in tasks_with_params} logger.info("Checking datasets %s for available tables", datasets) # Query the available tables for all datasets client = tasks_with_params[0][0].output().client available_datasets = filter(client.dataset_exists, datasets) available_tables = {d: set(client.list_tables(d)) for d in available_datasets} # Return parameter_tuples belonging to available tables for t, p in tasks_with_params: table = t.output().table if table.table_id in available_tables.get(table.dataset, []): yield p class BigQueryLoadTask(MixinBigQueryBulkComplete, luigi.Task): """Load data into BigQuery from GCS.""" @property def source_format(self): """The source format to use (see :py:class:`SourceFormat`).""" return SourceFormat.NEWLINE_DELIMITED_JSON @property def encoding(self): """The encoding of the data that is going to be loaded (see :py:class:`Encoding`).""" return Encoding.UTF_8 @property def write_disposition(self): """What to do if the table already exists. By default this will fail the job. 
See :py:class:`WriteDisposition`""" return WriteDisposition.WRITE_EMPTY @property def schema(self): """Schema in the format defined at https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.load.schema. If the value is falsy, it is omitted and inferred by BigQuery.""" return [] @property def max_bad_records(self): """The maximum number of bad records that BigQuery can ignore when reading data. If the number of bad records exceeds this value, an invalid error is returned in the job result.""" return 0 @property def field_delimiter(self): """The separator for fields in a CSV file. The separator can be any ISO-8859-1 single-byte character.""" return FieldDelimiter.COMMA def source_uris(self): """The fully-qualified URIs that point to your data in Google Cloud Storage. Each URI can contain one '*' wildcard character and it must come after the 'bucket' name.""" return [x.path for x in luigi.task.flatten(self.input())] @property def skip_leading_rows(self): """The number of rows at the top of a CSV file that BigQuery will skip when loading the data. The default value is 0. This property is useful if you have header rows in the file that should be skipped.""" return 0 @property def allow_jagged_rows(self): """Accept rows that are missing trailing optional columns. The missing values are treated as nulls. If false, records with missing trailing columns are treated as bad records, and if there are too many bad records, an invalid error is returned in the job result. The default value is false. Only applicable to CSV, ignored for other formats.""" return False @property def ignore_unknown_values(self): """Indicates if BigQuery should allow extra values that are not represented in the table schema. If true, the extra values are ignored. If false, records with extra columns are treated as bad records, and if there are too many bad records, an invalid error is returned in the job result. The default value is false. 
The sourceFormat property determines what BigQuery treats as an extra value: CSV: Trailing columns JSON: Named values that don't match any column names""" return False @property def allow_quoted_new_lines(self): """Indicates if BigQuery should allow quoted data sections that contain newline characters in a CSV file. The default value is false.""" return False def configure_job(self, configuration): """Set additional job configuration. This allows to specify job configuration parameters that are not exposed via Task properties. :param configuration: Current configuration. :return: New or updated configuration. """ return configuration def run(self): output = self.output() assert isinstance(output, BigQueryTarget), "Output must be a BigQueryTarget, not %s" % (output) bq_client = output.client source_uris = self.source_uris() assert all(x.startswith("gs://") for x in source_uris) job = { "configuration": { "load": { "destinationTable": { "projectId": output.table.project_id, "datasetId": output.table.dataset_id, "tableId": output.table.table_id, }, "encoding": self.encoding, "sourceFormat": self.source_format, "writeDisposition": self.write_disposition, "sourceUris": source_uris, "maxBadRecords": self.max_bad_records, "ignoreUnknownValues": self.ignore_unknown_values, } } } if self.source_format == SourceFormat.CSV: job["configuration"]["load"]["fieldDelimiter"] = self.field_delimiter job["configuration"]["load"]["skipLeadingRows"] = self.skip_leading_rows job["configuration"]["load"]["allowJaggedRows"] = self.allow_jagged_rows job["configuration"]["load"]["allowQuotedNewlines"] = self.allow_quoted_new_lines if self.schema: job["configuration"]["load"]["schema"] = {"fields": self.schema} else: job["configuration"]["load"]["autodetect"] = True job["configuration"] = self.configure_job(job["configuration"]) bq_client.run_job(output.table.project_id, job, dataset=output.table.dataset) class BigQueryRunQueryTask(MixinBigQueryBulkComplete, luigi.Task): @property def 
write_disposition(self): """What to do if the table already exists. By default this will fail the job. See :py:class:`WriteDisposition`""" return WriteDisposition.WRITE_TRUNCATE @property def create_disposition(self): """Whether to create the table or not. See :py:class:`CreateDisposition`""" return CreateDisposition.CREATE_IF_NEEDED @property def flatten_results(self): """Flattens all nested and repeated fields in the query results. allowLargeResults must be true if this is set to False.""" return True @property def query(self): """The query, in text form.""" raise NotImplementedError() @property def query_mode(self): """The query mode. See :py:class:`QueryMode`.""" return QueryMode.INTERACTIVE @property def udf_resource_uris(self): """Iterator of code resource to load from a Google Cloud Storage URI (gs://bucket/path).""" return [] @property def use_legacy_sql(self): """Whether to use legacy SQL""" return True def configure_job(self, configuration): """Set additional job configuration. This allows to specify job configuration parameters that are not exposed via Task properties. :param configuration: Current configuration. :return: New or updated configuration. 
""" return configuration def run(self): output = self.output() assert isinstance(output, BigQueryTarget), "Output must be a BigQueryTarget, not %s" % (output) query = self.query assert query, "No query was provided" bq_client = output.client logger.info("Launching Query") logger.info("Query destination: %s (%s)", output, self.write_disposition) logger.info("Query SQL: %s", query) job = { "configuration": { "query": { "query": query, "priority": self.query_mode, "destinationTable": { "projectId": output.table.project_id, "datasetId": output.table.dataset_id, "tableId": output.table.table_id, }, "allowLargeResults": True, "createDisposition": self.create_disposition, "writeDisposition": self.write_disposition, "flattenResults": self.flatten_results, "userDefinedFunctionResources": [{"resourceUri": v} for v in self.udf_resource_uris], "useLegacySql": self.use_legacy_sql, } } } job["configuration"] = self.configure_job(job["configuration"]) bq_client.run_job(output.table.project_id, job, dataset=output.table.dataset) class BigQueryCreateViewTask(luigi.Task): """ Creates (or updates) a view in BigQuery. The output of this task needs to be a BigQueryTarget. Instances of this class should specify the view SQL in the view property. If a view already exist in BigQuery at output(), it will be updated. 
""" @property def view(self): """The SQL query for the view, in text form.""" raise NotImplementedError() def complete(self): output = self.output() assert isinstance(output, BigQueryTarget), "Output must be a BigQueryTarget, not %s" % (output) if not output.exists(): return False existing_view = output.client.get_view(output.table) return existing_view == self.view def run(self): output = self.output() assert isinstance(output, BigQueryTarget), "Output must be a BigQueryTarget, not %s" % (output) view = self.view assert view, "No view was provided" logger.info("Create view") logger.info("Destination: %s", output) logger.info("View SQL: %s", view) output.client.update_view(output.table, view) class ExternalBigQueryTask(MixinBigQueryBulkComplete, luigi.ExternalTask): """ An external task for a BigQuery target. """ pass class BigQueryExtractTask(luigi.Task): """ Extracts (unloads) a table from BigQuery to GCS. This tasks requires the input to be exactly one BigQueryTarget while the output should be one or more GCSTargets from luigi.contrib.gcs depending on the use of destinationUris property. """ @property def destination_uris(self): """ The fully-qualified URIs that point to your data in Google Cloud Storage. Each URI can contain one '*' wildcard character and it must come after the 'bucket' name. Wildcarded destinationUris in GCSQueryTarget might not be resolved correctly and result in incomplete data. If a GCSQueryTarget is used to pass wildcarded destinationUris be sure to overwrite this property to suppress the warning. """ return [x.path for x in luigi.task.flatten(self.output())] @property def print_header(self): """Whether to print the header or not.""" return PrintHeader.TRUE @property def field_delimiter(self): """ The separator for fields in a CSV file. The separator can be any ISO-8859-1 single-byte character. """ return FieldDelimiter.COMMA @property def destination_format(self): """ The destination format to use (see :py:class:`DestinationFormat`). 
""" return DestinationFormat.CSV @property def compression(self): """Whether to use compression.""" return Compression.NONE def configure_job(self, configuration): """Set additional job configuration. This allows to specify job configuration parameters that are not exposed via Task properties. :param configuration: Current configuration. :return: New or updated configuration. """ return configuration def run(self): input = luigi.task.flatten(self.input())[0] assert isinstance(input, BigQueryTarget) or (len(input) == 1 and isinstance(input[0], BigQueryTarget)), ( "Input must be exactly one BigQueryTarget, not %s" % (input) ) bq_client = input.client destination_uris = self.destination_uris assert all(x.startswith("gs://") for x in destination_uris) logger.info("Launching Extract Job") logger.info("Extract source: %s", input) logger.info("Extract destination: %s", destination_uris) job = { "configuration": { "extract": { "sourceTable": {"projectId": input.table.project_id, "datasetId": input.table.dataset_id, "tableId": input.table.table_id}, "destinationUris": destination_uris, "destinationFormat": self.destination_format, "compression": self.compression, } } } if self.destination_format == "CSV": # "Only exports to CSV may specify a field delimiter." 
job["configuration"]["extract"]["printHeader"] = self.print_header job["configuration"]["extract"]["fieldDelimiter"] = self.field_delimiter job["configuration"] = self.configure_job(job["configuration"]) bq_client.run_job(input.table.project_id, job, dataset=input.table.dataset) # the original inconsistently capitalized aliases, for backwards compatibility BigqueryClient = BigQueryClient BigqueryTarget = BigQueryTarget MixinBigqueryBulkComplete = MixinBigQueryBulkComplete BigqueryLoadTask = BigQueryLoadTask BigqueryRunQueryTask = BigQueryRunQueryTask BigqueryCreateViewTask = BigQueryCreateViewTask ExternalBigqueryTask = ExternalBigQueryTask class BigQueryExecutionError(Exception): def __init__(self, job_id, error_message) -> None: """ :param job_id: BigQuery Job ID :type job_id: str :param error_message: status['status']['errorResult'] for the failed job :type error_message: str """ super().__init__("BigQuery job {} failed: {}".format(job_id, error_message)) self.error_message = error_message self.job_id = job_id ================================================ FILE: luigi/contrib/bigquery_avro.py ================================================ """Specialized tasks for handling Avro data in BigQuery from GCS.""" import logging from luigi.contrib.bigquery import BigQueryLoadTask, SourceFormat from luigi.contrib.gcs import GCSClient from luigi.task import flatten logger = logging.getLogger("luigi-interface") try: import avro import avro.datafile except ImportError: logger.warning("bigquery_avro module imported, but avro is not installed. Any BigQueryLoadAvro task will fail to propagate schema documentation") class BigQueryLoadAvro(BigQueryLoadTask): """A helper for loading specifically Avro data into BigQuery from GCS. Copies table level description from Avro schema doc, BigQuery internally will copy field-level descriptions to the table. 
Suitable for use via subclassing: override requires() to return Task(s) that output to GCS Targets; their paths are expected to be URIs of .avro files or URI prefixes (GCS "directories") containing one or many .avro files. Override output() to return a BigQueryTarget representing the destination table. """ source_format = SourceFormat.AVRO def _avro_uri(self, target): path_or_uri = target.uri if hasattr(target, "uri") else target.path return path_or_uri if path_or_uri.endswith(".avro") else path_or_uri.rstrip("/") + "/*.avro" def source_uris(self): return [self._avro_uri(x) for x in flatten(self.input())] def _get_input_schema(self): """Arbitrarily picks an object in input and reads the Avro schema from it.""" assert avro, "avro module required" input_target = flatten(self.input())[0] input_fs = input_target.fs if hasattr(input_target, "fs") else GCSClient() input_uri = self.source_uris()[0] if "*" in input_uri: file_uris = list(input_fs.list_wildcard(input_uri)) if file_uris: input_uri = file_uris[0] else: raise RuntimeError("No match for " + input_uri) schema = [] exception_reading_schema = [] def read_schema(fp): # fp contains the file part downloaded thus far. We rely on that the DataFileReader # initializes itself fine as soon as the file header with schema is downloaded, without # requiring the remainder of the file... try: reader = avro.datafile.DataFileReader(fp, avro.io.DatumReader()) schema[:] = [BigQueryLoadAvro._get_writer_schema(reader.datum_reader)] except Exception as e: # Save but assume benign unless schema reading ultimately fails. The benign # exception in case of insufficiently big downloaded file part seems to be: # TypeError('ord() expected a character, but string of length 0 found',). 
exception_reading_schema[:] = [e] return False return True input_fs.download(input_uri, 64 * 1024, read_schema).close() if not schema: raise exception_reading_schema[0] return schema[0] @staticmethod def _get_writer_schema(datum_reader): """Python-version agnostic getter for datum_reader writer(s)_schema attribute Parameters: datum_reader (avro.io.DatumReader): DatumReader Returns: Returning correct attribute name depending on Python version. """ return datum_reader.writer_schema def _set_output_doc(self, avro_schema): bq_client = self.output().client.client table = self.output().table patch = { "description": avro_schema.doc, } bq_client.tables().patch(projectId=table.project_id, datasetId=table.dataset_id, tableId=table.table_id, body=patch).execute() def run(self): super(BigQueryLoadAvro, self).run() # We propagate documentation in one fire-and-forget attempt; the output table is # left to exist without documentation if this step raises an exception. try: self._set_output_doc(self._get_input_schema()) except Exception as e: logger.warning("Could not propagate Avro doc to BigQuery table description: %r", e) ================================================ FILE: luigi/contrib/datadog_metric.py ================================================ import logging from luigi import parameter from luigi.metrics import MetricsCollector from luigi.task import Config logger = logging.getLogger("luigi-interface") try: from datadog import api, initialize, statsd except ImportError: logger.warning("Loading datadog module without datadog installed. 
Will crash at runtime if datadog functionality is used.") class datadog(Config): api_key = parameter.Parameter(default="dummy_api_key", description="API key provided by Datadog") app_key = parameter.Parameter(default="dummy_app_key", description="APP key provided by Datadog") default_tags = parameter.Parameter(default="application:luigi", description="Default tags for every events and metrics sent to Datadog") environment = parameter.Parameter(default="development", description="Environment of which the pipeline is ran from (eg: 'production', 'staging', ...") metric_namespace = parameter.Parameter(default="luigi", description="Default namespace for events and metrics (eg: 'luigi' for 'luigi.task.started')") statsd_host = parameter.Parameter(default="localhost", description="StatsD host implementing the Datadog service") statsd_port = parameter.IntParameter(default=8125, description="StatsD port implementing the Datadog service") class DatadogMetricsCollector(MetricsCollector): def __init__(self, *args, **kwargs): self._config = datadog(**kwargs) initialize(api_key=self._config.api_key, app_key=self._config.app_key, statsd_host=self._config.statsd_host, statsd_port=self._config.statsd_port) def handle_task_started(self, task): title = "Luigi: A task has been started!" text = "A task has been started in the pipeline named: {name}".format(name=task.family) tags = ["task_name:{name}".format(name=task.family)] + self._format_task_params_to_tags(task) self._send_increment("task.started", tags=tags) event_tags = tags + ["task_state:STARTED"] self._send_event(title=title, text=text, tags=event_tags, alert_type="info", priority="low") def handle_task_failed(self, task): title = "Luigi: A task has failed!" 
text = "A task has failed in the pipeline named: {name}".format(name=task.family) tags = ["task_name:{name}".format(name=task.family)] + self._format_task_params_to_tags(task) self._send_increment("task.failed", tags=tags) event_tags = tags + ["task_state:FAILED"] self._send_event(title=title, text=text, tags=event_tags, alert_type="error", priority="normal") def handle_task_disabled(self, task, config): title = "Luigi: A task has been disabled!" lines = ["A task has been disabled in the pipeline named: {name}."] lines.append("The task has failed {failures} times in the last {window}") lines.append("seconds, so it is being disabled for {persist} seconds.") preformated_text = " ".join(lines) text = preformated_text.format(name=task.family, persist=config.disable_persist, failures=config.retry_count, window=config.disable_window) tags = ["task_name:{name}".format(name=task.family)] + self._format_task_params_to_tags(task) self._send_increment("task.disabled", tags=tags) event_tags = tags + ["task_state:DISABLED"] self._send_event(title=title, text=text, tags=event_tags, alert_type="error", priority="normal") def handle_task_done(self, task): # The task is already done -- Let's not re-create an event if task.time_running is None: return title = "Luigi: A task has been completed!" 
text = "A task has completed in the pipeline named: {name}".format(name=task.family) tags = ["task_name:{name}".format(name=task.family)] + self._format_task_params_to_tags(task) time_elapse = task.updated - task.time_running self._send_increment("task.done", tags=tags) self._send_gauge("task.execution_time", time_elapse, tags=tags) event_tags = tags + ["task_state:DONE"] self._send_event(title=title, text=text, tags=event_tags, alert_type="info", priority="low") def _send_event(self, **params): params["tags"] += self.default_tags api.Event.create(**params) def _send_gauge(self, metric_name, value, tags=[]): all_tags = tags + self.default_tags namespaced_metric = "{namespace}.{metric_name}".format(namespace=self._config.metric_namespace, metric_name=metric_name) statsd.gauge(namespaced_metric, value, tags=all_tags) def _send_increment(self, metric_name, value=1, tags=[]): all_tags = tags + self.default_tags namespaced_metric = "{namespace}.{metric_name}".format(namespace=self._config.metric_namespace, metric_name=metric_name) statsd.increment(namespaced_metric, value, tags=all_tags) def _format_task_params_to_tags(self, task): params = [] for key, value in task.params.items(): params.append("{key}:{value}".format(key=key, value=value)) return params @property def default_tags(self): default_tags = [] env_tag = "environment:{environment}".format(environment=self._config.environment) default_tags.append(env_tag) if self._config.default_tags: default_tags = default_tags + str.split(self._config.default_tags, ",") return default_tags ================================================ FILE: luigi/contrib/dataproc.py ================================================ """luigi bindings for Google Dataproc on Google Cloud""" import logging import os import time import luigi from luigi.contrib import gcp logger = logging.getLogger("luigi-interface") _dataproc_client = None try: import google.auth from googleapiclient import discovery from googleapiclient.errors import HttpError 
DEFAULT_CREDENTIALS, _ = google.auth.default() authenticate_kwargs = gcp.get_authenticate_kwargs(DEFAULT_CREDENTIALS) _dataproc_client = discovery.build("dataproc", "v1", cache_discovery=False, **authenticate_kwargs) except ImportError: logger.warning( "Loading Dataproc module without the python packages googleapiclient & google-auth. \ This will crash at runtime if Dataproc functionality is used." ) def get_dataproc_client(): return _dataproc_client def set_dataproc_client(client): global _dataproc_client _dataproc_client = client class _DataprocBaseTask(luigi.Task): gcloud_project_id = luigi.Parameter(significant=False, positional=False) dataproc_cluster_name = luigi.Parameter(significant=False, positional=False) dataproc_region = luigi.Parameter(default="global", significant=False, positional=False) dataproc_client = get_dataproc_client() class DataprocBaseTask(_DataprocBaseTask): """ Base task for running jobs in Dataproc. It is recommended to use one of the tasks specific to your job type. Extend this class if you need fine grained control over what kind of job gets submitted to your Dataproc cluster. 
""" _job = None _job_name = None _job_id = None def submit_job(self, job_config): self._job = ( self.dataproc_client.projects().regions().jobs().submit(projectId=self.gcloud_project_id, region=self.dataproc_region, body=job_config).execute() ) self._job_id = self._job["reference"]["jobId"] return self._job def submit_spark_job(self, jars, main_class, job_args=None): if job_args is None: job_args = [] job_config = { "job": {"placement": {"clusterName": self.dataproc_cluster_name}, "sparkJob": {"args": job_args, "mainClass": main_class, "jarFileUris": jars}} } self.submit_job(job_config) self._job_name = os.path.basename(self._job["sparkJob"]["mainClass"]) logger.info("Submitted new dataproc job:{} id:{}".format(self._job_name, self._job_id)) return self._job def submit_pyspark_job(self, job_file, extra_files=list(), job_args=None): if job_args is None: job_args = [] job_config = { "job": { "placement": {"clusterName": self.dataproc_cluster_name}, "pysparkJob": {"mainPythonFileUri": job_file, "pythonFileUris": extra_files, "args": job_args}, } } self.submit_job(job_config) self._job_name = os.path.basename(self._job["pysparkJob"]["mainPythonFileUri"]) logger.info("Submitted new dataproc job:{} id:{}".format(self._job_name, self._job_id)) return self._job def wait_for_job(self): if self._job is None: raise Exception("You must submit a job before you can wait for it") while True: job_result = ( self.dataproc_client.projects() .regions() .jobs() .get(projectId=self.gcloud_project_id, region=self.dataproc_region, jobId=self._job_id) .execute() ) status = job_result["status"]["state"] logger.info("Current dataproc status: {} job:{} id:{}".format(status, self._job_name, self._job_id)) if status == "DONE": break if status == "ERROR": raise Exception(job_result["status"]["details"]) time.sleep(5) class DataprocSparkTask(DataprocBaseTask): """ Runs a spark jobs on your Dataproc cluster """ main_class = luigi.Parameter() jars = luigi.Parameter(default="") job_args = 
luigi.Parameter(default="") def run(self): self.submit_spark_job( main_class=self.main_class, jars=self.jars.split(",") if self.jars else [], job_args=self.job_args.split(",") if self.job_args else [] ) self.wait_for_job() class DataprocPysparkTask(DataprocBaseTask): """ Runs a pyspark jobs on your Dataproc cluster """ job_file = luigi.Parameter() extra_files = luigi.Parameter(default="") job_args = luigi.Parameter(default="") def run(self): self.submit_pyspark_job( job_file=self.job_file, extra_files=self.extra_files.split(",") if self.extra_files else [], job_args=self.job_args.split(",") if self.job_args else [], ) self.wait_for_job() class CreateDataprocClusterTask(_DataprocBaseTask): """Task for creating a Dataproc cluster.""" gcloud_zone = luigi.Parameter(default="europe-west1-c") gcloud_network = luigi.Parameter(default="default") master_node_type = luigi.Parameter(default="n1-standard-2") master_disk_size = luigi.Parameter(default="100") worker_node_type = luigi.Parameter(default="n1-standard-2") worker_disk_size = luigi.Parameter(default="100") worker_normal_count = luigi.Parameter(default="2") worker_preemptible_count = luigi.Parameter(default="0") image_version = luigi.Parameter(default="") def _get_cluster_status(self): return ( self.dataproc_client.projects() .regions() .clusters() .get(projectId=self.gcloud_project_id, region=self.dataproc_region, clusterName=self.dataproc_cluster_name) .execute() ) def complete(self): try: self._get_cluster_status() return True # No (404) error so the cluster already exists except HttpError as e: if e.resp.status == 404: return False # We got a 404 so the cluster doesn't exist yet else: raise e # Something's wrong ... 
def run(self): base_uri = "https://www.googleapis.com/compute/v1/projects/{}".format(self.gcloud_project_id) software_config = {"imageVersion": self.image_version} if self.image_version else {} cluster_conf = { "clusterName": self.dataproc_cluster_name, "projectId": self.gcloud_project_id, "config": { "configBucket": "", "gceClusterConfig": { "networkUri": base_uri + "/global/networks/" + self.gcloud_network, "zoneUri": base_uri + "/zones/" + self.gcloud_zone, "serviceAccountScopes": ["https://www.googleapis.com/auth/cloud-platform"], }, "masterConfig": { "numInstances": 1, "machineTypeUri": base_uri + "/zones/" + self.gcloud_zone + "/machineTypes/" + self.master_node_type, "diskConfig": {"bootDiskSizeGb": self.master_disk_size, "numLocalSsds": 0}, }, "workerConfig": { "numInstances": self.worker_normal_count, "machineTypeUri": base_uri + "/zones/" + self.gcloud_zone + "/machineTypes/" + self.worker_node_type, "diskConfig": {"bootDiskSizeGb": self.worker_disk_size, "numLocalSsds": 0}, }, "secondaryWorkerConfig": {"numInstances": self.worker_preemptible_count, "isPreemptible": True}, "softwareConfig": software_config, }, } self.dataproc_client.projects().regions().clusters().create(projectId=self.gcloud_project_id, region=self.dataproc_region, body=cluster_conf).execute() while True: time.sleep(10) cluster_status = self._get_cluster_status() status = cluster_status["status"]["state"] logger.info("Creating new dataproc cluster: {} status: {}".format(self.dataproc_cluster_name, status)) if status == "RUNNING": break if status == "ERROR": raise Exception(cluster_status["status"]["details"]) class DeleteDataprocClusterTask(_DataprocBaseTask): """ Task for deleting a Dataproc cluster. One of the uses for this class is to extend it and have it require a Dataproc task that does a calculation and have that task extend the cluster creation task. This allows you to create chains where you create a cluster, run your job and remove the cluster right away. 
(Store your input and output files in gs://... instead of hdfs://... if you do this). """ def _get_cluster_status(self): try: return ( self.dataproc_client.projects() .regions() .clusters() .get(projectId=self.gcloud_project_id, region=self.dataproc_region, clusterName=self.dataproc_cluster_name, fields="status") .execute() ) except HttpError as e: if e.resp.status == 404: return None # We got a 404 so the cluster doesn't exist else: raise e def complete(self): return self._get_cluster_status() is None def run(self): self.dataproc_client.projects().regions().clusters().delete( projectId=self.gcloud_project_id, region=self.dataproc_region, clusterName=self.dataproc_cluster_name ).execute() while True: time.sleep(10) status = self._get_cluster_status() if status is None: logger.info("Finished shutting down cluster: {}".format(self.dataproc_cluster_name)) break logger.info("Shutting down cluster: {} current status: {}".format(self.dataproc_cluster_name, status["status"]["state"])) ================================================ FILE: luigi/contrib/docker_runner.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2017 Open Targets # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Docker container wrapper for Luigi. Enables running a docker container as a task in luigi. This wrapper uses the Docker Python SDK to communicate directly with the Docker API avoiding the common pattern to invoke the docker client from the command line. 
Using the SDK it is possible to detect and properly handle errors occurring when pulling, starting or running the containers. On top of this, it is possible to mount a single file in the container and a temporary directory is created on the host and mounted allowing the handling of files bigger than the container limit. Requires: - docker: ``pip install docker`` Written and maintained by Andrea Pierleoni (@apierleoni). Contributions by Eliseo Papa (@elipapa). """ import logging from tempfile import mkdtemp import luigi from luigi.local_target import LocalFileSystem logger = logging.getLogger("luigi-interface") try: import docker from docker.errors import APIError, ContainerError, ImageNotFound except ImportError: logger.warning("docker is not installed. DockerTask requires docker.") docker = None # TODO: may need to implement this logic for remote hosts # class dockerconfig(luigi.Config): # ''' # this class allows to use the luigi.cfg file to specify the path to the docker config.json. # The docker client should look by default in the main directory, # but on different systems this may need to be specified. # ''' # docker_config_path = luigi.Parameter( # default="~/.docker/config.json", # description="Path to dockercfg file for authentication") class DockerTask(luigi.Task): @property def image(self): return "alpine" @property def command(self): return "echo hello world" @property def name(self): return None @property def host_config_options(self): """ Override this to specify host_config options like gpu requests or shm size e.g. `{"device_requests": [docker.types.DeviceRequest(count=1, capabilities=[["gpu"]])]}` See https://docker-py.readthedocs.io/en/stable/api.html#docker.api.container.ContainerApiMixin.create_host_config """ return {} @property def container_options(self): """ Override this to specify container options like user or ports e.g. 
`{"user": f"{os.getuid()}:{os.getgid()}"}` See https://docker-py.readthedocs.io/en/stable/api.html#docker.api.container.ContainerApiMixin.create_container """ return {} @property def environment(self): return {} @property def container_tmp_dir(self): return "/tmp/luigi" @property def binds(self): """ Override this to mount local volumes, in addition to the /tmp/luigi which gets defined by default. This should return a list of strings. e.g. ['/hostpath1:/containerpath1', '/hostpath2:/containerpath2'] """ return None @property def network_mode(self): return "" @property def docker_url(self): return None @property def auto_remove(self): return True @property def force_pull(self): return False @property def mount_tmp(self): return True def __init__(self, *args, **kwargs): """ When a new instance of the DockerTask class gets created: - call the parent class __init__ method - start the logger - init an instance of the docker client - create a tmp dir - add the temp dir to the volume binds specified in the task """ super(DockerTask, self).__init__(*args, **kwargs) self.__logger = logger """init docker client using the low level API as the higher level API does not allow to mount single files as volumes """ self._client = docker.APIClient(self.docker_url) # add latest tag if nothing else is specified by task if ":" not in self.image: self._image = ":".join([self.image, "latest"]) else: self._image = self.image if self.mount_tmp: # create a tmp_dir, NOTE: /tmp needs to be specified for it to work on # macOS, despite what the python documentation says self._host_tmp_dir = mkdtemp(suffix=self.task_id, prefix="luigi-docker-tmp-dir-", dir="/tmp") self._binds = ["{0}:{1}".format(self._host_tmp_dir, self.container_tmp_dir)] else: self._binds = [] # update environment property with the (internal) location of tmp_dir self.environment["LUIGI_TMP_DIR"] = self.container_tmp_dir # add additional volume binds specified by the user to the tmp_Dir bind if isinstance(self.binds, str): 
self._binds.append(self.binds) elif isinstance(self.binds, list): self._binds.extend(self.binds) # derive volumes (ie. list of container destination paths) from # specified binds self._volumes = [b.split(":")[1] for b in self._binds] def run(self): # get image if missing if self.force_pull or len(self._client.images(name=self._image)) == 0: logger.info("Pulling docker image " + self._image) try: for logline in self._client.pull(self._image, stream=True): logger.debug(logline.decode("utf-8")) except APIError as e: self.__logger.warning("Error in Docker API: " + e.explanation) raise # remove clashing container if a container with the same name exists if self.auto_remove and self.name: try: self._client.remove_container(self.name, force=True) except APIError as e: self.__logger.warning("Ignored error in Docker API: " + e.explanation) # run the container try: logger.debug("Creating image: %s command: %s volumes: %s" % (self._image, self.command, self._binds)) host_config = self._client.create_host_config(binds=self._binds, network_mode=self.network_mode, **self.host_config_options) container = self._client.create_container( self._image, command=self.command, name=self.name, environment=self.environment, volumes=self._volumes, host_config=host_config, **self.container_options, ) self._client.start(container["Id"]) exit_status = self._client.wait(container["Id"]) # docker-py>=3.0.0 returns a dict instead of the status code directly if type(exit_status) is dict: exit_status = exit_status["StatusCode"] if exit_status != 0: stdout = False stderr = True error = self._client.logs(container["Id"], stdout=stdout, stderr=stderr) if self.auto_remove: try: self._client.remove_container(container["Id"]) except docker.errors.APIError: self.__logger.warning("Container " + container["Id"] + " could not be removed") if exit_status != 0: raise ContainerError(container, exit_status, self.command, self._image, error) except ContainerError as e: # catch non zero exti status and return it 
container_name = "" if self.name: container_name = self.name try: message = e.message except AttributeError: message = str(e) self.__logger.error("Container " + container_name + " exited with non zero code: " + message) raise except ImageNotFound: self.__logger.error("Image " + self._image + " not found") raise except APIError as e: self.__logger.error("Error in Docker API: " + e.explanation) raise # delete temp dir filesys = LocalFileSystem() if self.mount_tmp and filesys.exists(self._host_tmp_dir): filesys.remove(self._host_tmp_dir, recursive=True) ================================================ FILE: luigi/contrib/dropbox.py ================================================ # -*- coding: utf-8 -*- # # Copyright (c) 2019 Jose-Ignacio Riaño Chico # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of # the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations under # the License. # import logging import ntpath import os import random import tempfile import time from contextlib import contextmanager from functools import wraps import luigi.format from luigi.target import AtomicLocalFile, FileSystem, FileSystemTarget logger = logging.getLogger("luigi-interface") try: import dropbox.dropbox_client import dropbox.exceptions import dropbox.files except ImportError: logger.warning( "Loading Dropbox module without the python package dropbox (https://pypi.org/project/dropbox/). Will crash at runtime if Dropbox functionality is used." 
) def accept_trailing_slash_in_existing_dirpaths(func): @wraps(func) def wrapped(self, path, *args, **kwargs): if path != "/" and path.endswith("/"): logger.warning("Dropbox paths should NOT have trailing slashes. This causes additional API calls") logger.warning("Consider modifying your calls to {}, so that they don't use paths than end with '/'".format(func.__name__)) if self._exists_and_is_dir(path[:-1]): path = path[:-1] return func(self, path, *args, **kwargs) return wrapped def accept_trailing_slash(func): @wraps(func) def wrapped(self, path, *args, **kwargs): if path != "/" and path.endswith("/"): path = path[:-1] return func(self, path, *args, **kwargs) return wrapped class DropboxClient(FileSystem): """ Dropbox client for authentication, designed to be used by the :py:class:`DropboxTarget` class. """ def __init__(self, token, user_agent="Luigi", root_namespace_id=None): """ :param str token: Dropbox Oauth2 Token. See :class:`DropboxTarget` for more information about generating a token :param str root_namespace_id: Root namespace ID for interacting with Team Spaces """ if not token: raise ValueError("The token parameter must contain a valid Dropbox Oauth2 Token") try: conn = dropbox.dropbox_client.Dropbox(oauth2_access_token=token, user_agent=user_agent) except Exception as e: raise Exception("Cannot connect to Dropbox. Check your Internet connection and the token. 
\n" + repr(e)) if root_namespace_id: conn = conn.with_path_root(dropbox.common.PathRoot.root(root_namespace_id)) self.token = token self.conn = conn @accept_trailing_slash_in_existing_dirpaths def exists(self, path): if path == "/": return True if path.endswith("/"): path = path[:-1] return self._exists_and_is_dir(path) try: self.conn.files_get_metadata(path) return True except dropbox.exceptions.ApiError as e: if isinstance(e.error.get_path(), dropbox.files.LookupError): return False else: raise e @accept_trailing_slash_in_existing_dirpaths def remove(self, path, recursive=True, skip_trash=True): if not self.exists(path): return False self.conn.files_delete_v2(path) return True @accept_trailing_slash def mkdir(self, path, parents=True, raise_if_exists=False): if self.exists(path): if not self.isdir(path): raise luigi.target.NotADirectory() elif raise_if_exists: raise luigi.target.FileAlreadyExists() else: return self.conn.files_create_folder_v2(path) @accept_trailing_slash_in_existing_dirpaths def isdir(self, path): if path == "/": return True try: md = self.conn.files_get_metadata(path) return isinstance(md, dropbox.files.FolderMetadata) except dropbox.exceptions.ApiError as e: if isinstance(e.error.get_path(), dropbox.files.LookupError): return False else: raise e @accept_trailing_slash_in_existing_dirpaths def listdir(self, path, **kwargs): dirs = [] lister = self.conn.files_list_folder(path, recursive=True, **kwargs) dirs.extend(lister.entries) while lister.has_more: lister = self.conn.files_list_folder_continue(lister.cursor) dirs.extend(lister.entries) return [d.path_display for d in dirs] @accept_trailing_slash_in_existing_dirpaths def move(self, path, dest): self.conn.files_move_v2(from_path=path, to_path=dest) @accept_trailing_slash_in_existing_dirpaths def copy(self, path, dest): self.conn.files_copy_v2(from_path=path, to_path=dest) def download_as_bytes(self, path): metadata, response = self.conn.files_download(path) return response.content def 
upload(self, tmp_path, dest_path): with open(tmp_path, "rb") as f: file_size = os.path.getsize(tmp_path) CHUNK_SIZE = 4 * 1000 * 1000 upload_session_start_result = self.conn.files_upload_session_start(f.read(CHUNK_SIZE)) commit = dropbox.files.CommitInfo(path=dest_path) cursor = dropbox.files.UploadSessionCursor(session_id=upload_session_start_result.session_id, offset=f.tell()) if f.tell() >= file_size: self.conn.files_upload_session_finish(f.read(CHUNK_SIZE), cursor, commit) return while f.tell() < file_size: if (file_size - f.tell()) <= CHUNK_SIZE: self.conn.files_upload_session_finish(f.read(CHUNK_SIZE), cursor, commit) else: self.conn.files_upload_session_append_v2(f.read(CHUNK_SIZE), cursor) cursor.offset = f.tell() def _exists_and_is_dir(self, path): """ Auxiliary method, used by the 'accept_trailing_slash' and 'accept_trailing_slash_in_existing_dirpaths' decorators :param path: a Dropbox path that does NOT ends with a '/' (even if it is a directory) """ if path == "/": return True try: md = self.conn.files_get_metadata(path) is_dir = isinstance(md, dropbox.files.FolderMetadata) return is_dir except dropbox.exceptions.ApiError: return False class ReadableDropboxFile: def __init__(self, path, client): """ Represents a file inside the Dropbox cloud which will be read :param str path: Dropbpx path of the file to be read (always starting with /) :param DropboxClient client: a DropboxClient object (initialized with a valid token) """ self.path = path self.client = client self.download_file_location = os.path.join(tempfile.mkdtemp(prefix=str(time.time())), ntpath.basename(path)) self.fid = None self.closed = False def read(self): return self.client.download_as_bytes(self.path) def __enter__(self): return self def __exit__(self, exc_type, exc, traceback): self.close() def __del__(self): self.close() if os.path.exists(self.download_file_location): os.remove(self.download_file_location) def close(self): self.closed = True def readable(self): return True def 
writable(self): return False def seekable(self): return False class AtomicWritableDropboxFile(AtomicLocalFile): def __init__(self, path, client): """ Represents a file that will be created inside the Dropbox cloud :param str path: Destination path inside Dropbox :param DropboxClient client: a DropboxClient object (initialized with a valid token, for the desired account) """ super(AtomicWritableDropboxFile, self).__init__(path) self.path = path self.client = client def move_to_final_destination(self): """ After editing the file locally, this function uploads it to the Dropbox cloud """ self.client.upload(self.tmp_path, self.path) class DropboxTarget(FileSystemTarget): """ A Dropbox filesystem target. """ def __init__(self, path, token, format=None, user_agent="Luigi", root_namespace_id=None): """ Create an Dropbox Target for storing data in a dropbox.com account **About the path parameter** The path must start with '/' and should not end with '/' (even if it is a directory). The path must not contain adjacent slashes ('/files//img.jpg' is an invalid path) If the app has 'App folder' access, then / will refer to this app folder (which mean that there is no need to prepend the name of the app to the path) Otherwise, if the app has 'full access', then / will refer to the root of the Dropbox folder **About the token parameter:** The Dropbox target requires a valid OAuth2 token as a parameter (which means that a `Dropbox API app `_ must be created. This app can have 'App folder' access or 'Full Dropbox', as desired). Information about generating the token can be read here: - https://dropbox-sdk-python.readthedocs.io/en/latest/api/oauth.html#dropbox.oauth.DropboxOAuth2Flow - https://blogs.dropbox.com/developers/2014/05/generate-an-access-token-for-your-own-account/ :param str path: Remote path in Dropbox (starting with '/'). :param str token: a valid OAuth2 Dropbox token. :param luigi.Format format: the luigi format to use (e.g. 
`luigi.format.Nop`) :param str root_namespace_id: Root namespace ID for interacting with Team Spaces """ super(DropboxTarget, self).__init__(path) if not token: raise ValueError("The token parameter must contain a valid Dropbox Oauth2 Token") self.path = path self.token = token self.client = DropboxClient(token, user_agent, root_namespace_id) self.format = format or luigi.format.get_default_format() def __str__(self): return self.path @property def fs(self): return self.client @contextmanager def temporary_path(self): tmp_dir = tempfile.mkdtemp() num = random.randrange(0, 10_000_000_000) temp_path = "{}{}luigi-tmp-{:010}{}".format(tmp_dir, os.sep, num, ntpath.basename(self.path)) yield temp_path # We won't reach here if there was an user exception. self.fs.upload(temp_path, self.path) def open(self, mode): if mode not in ("r", "w"): raise ValueError("Unsupported open mode '%s'" % mode) if mode == "r": return self.format.pipe_reader(ReadableDropboxFile(self.path, self.client)) else: return self.format.pipe_writer(AtomicWritableDropboxFile(self.path, self.client)) ================================================ FILE: luigi/contrib/ecs.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015 Outlier Bio, LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """ EC2 Container Service wrapper for Luigi From the AWS website: Amazon EC2 Container Service (ECS) is a highly scalable, high performance container management service that supports Docker containers and allows you to easily run applications on a managed cluster of Amazon EC2 instances. To use ECS, you create a taskDefinition_ JSON that defines the `docker run`_ command for one or more containers in a task or service, and then submit this JSON to the API to run the task. This `boto3-powered`_ wrapper allows you to create Luigi Tasks to submit ECS ``taskDefinition`` s. You can either pass a dict (mapping directly to the ``taskDefinition`` JSON) OR an Amazon Resource Name (arn) for a previously registered ``taskDefinition``. Requires: - boto3 package - Amazon AWS credentials discoverable by boto3 (e.g., by using ``aws configure`` from awscli_) - A running ECS cluster (see `ECS Get Started`_) Written and maintained by Jake Feala (@jfeala) for Outlier Bio (@outlierbio) .. _`docker run`: https://docs.docker.com/reference/commandline/run .. _taskDefinition: http://docs.aws.amazon.com/AmazonECS/latest/developerguide/task_defintions.html .. _`boto3-powered`: https://boto3.readthedocs.io .. _awscli: https://aws.amazon.com/cli .. _`ECS Get Started`: http://docs.aws.amazon.com/AmazonECS/latest/developerguide/ECS_GetStarted.html """ import copy import logging import time import luigi logger = logging.getLogger("luigi-interface") try: import boto3 client = boto3.client("ecs") except ImportError: logger.warning("boto3 is not installed. 
ECSTasks require boto3") POLL_TIME = 2 def _get_task_statuses(task_ids, cluster): """ Retrieve task statuses from ECS API Returns list of {RUNNING|PENDING|STOPPED} for each id in task_ids """ response = client.describe_tasks(tasks=task_ids, cluster=cluster) # Error checking if response["failures"] != []: raise Exception("There were some failures:\n{0}".format(response["failures"])) status_code = response["ResponseMetadata"]["HTTPStatusCode"] if status_code != 200: msg = "Task status request received status code {0}:\n{1}" raise Exception(msg.format(status_code, response)) return [t["lastStatus"] for t in response["tasks"]] def _track_tasks(task_ids, cluster): """Poll task status until STOPPED""" while True: statuses = _get_task_statuses(task_ids, cluster) if all([status == "STOPPED" for status in statuses]): logger.info("ECS tasks {0} STOPPED".format(",".join(task_ids))) break time.sleep(POLL_TIME) logger.debug("ECS task status for tasks {0}: {1}".format(task_ids, statuses)) class ECSTask(luigi.Task): """ Base class for an Amazon EC2 Container Service Task Amazon ECS requires you to register "tasks", which are JSON descriptions for how to issue the ``docker run`` command. This Luigi Task can either run a pre-registered ECS taskDefinition, OR register the task on the fly from a Python dict. :param task_def_arn: pre-registered task definition ARN (Amazon Resource Name), of the form:: arn:aws:ecs:::task-definition/: :param task_def: dict describing task in taskDefinition JSON format, for example:: task_def = { 'family': 'hello-world', 'volumes': [], 'containerDefinitions': [ { 'memory': 1, 'essential': True, 'name': 'hello-world', 'image': 'ubuntu', 'command': ['/bin/echo', 'hello world'] } ] } :param cluster: str defining the ECS cluster to use. When this is not defined it will use the default one. 
""" task_def_arn = luigi.OptionalParameter(default=None) task_def = luigi.OptionalParameter(default=None) cluster = luigi.Parameter(default="default") @property def ecs_task_ids(self): """Expose the ECS task ID""" if hasattr(self, "_task_ids"): return self._task_ids @property def command(self): """ Command passed to the containers Override to return list of dicts with keys 'name' and 'command', describing the container names and commands to pass to the container. These values will be specified in the `containerOverrides` property of the `overrides` parameter passed to the runTask API. Example:: [ { 'name': 'myContainer', 'command': ['/bin/sleep', '60'] } ] """ pass @staticmethod def update_container_overrides_command(container_overrides, command): """ Update a list of container overrides with the specified command. The specified command will take precedence over any existing commands in `container_overrides` for the same container name. If no existing command yet exists in `container_overrides` for the specified command, it will be added. """ for colliding_override in filter(lambda x: x["name"] == command["name"], container_overrides): colliding_override["command"] = command["command"] break else: container_overrides.append(command) @property def combined_overrides(self): """ Return single dict combining any provided `overrides` parameters. This is used to allow custom `overrides` parameters to be specified in `self.run_task_kwargs` while ensuring that the values specified in `self.command` are honored in `containerOverrides`. """ overrides = copy.deepcopy(self.run_task_kwargs.get("overrides", {})) if self.command: if "containerOverrides" in overrides: for command in self.command: self.update_container_overrides_command(overrides["containerOverrides"], command) else: overrides["containerOverrides"] = self.command return overrides @property def run_task_kwargs(self): """ Additional keyword arguments to be provided to ECS runTask API. 
Override this property in a subclass to provide additional parameters such as `network_configuration`, `launchType`, etc. If the returned `dict` includes an `overrides` value with a nested `containerOverrides` array defining one or more container `command` values, prior to calling `run_task` they will be combined with and superseded by any colliding values specified separately in the `command` property. Example:: { 'launchType': 'FARGATE', 'platformVersion': '1.4.0', 'networkConfiguration': { 'awsvpcConfiguration': { 'subnets': [ 'subnet-01234567890abcdef', 'subnet-abcdef01234567890' ], 'securityGroups': [ 'sg-abcdef01234567890', ], 'assignPublicIp': 'ENABLED' } }, 'overrides': { 'ephemeralStorage': { 'sizeInGiB': 30 } } } """ return {} def run(self): if (not self.task_def and not self.task_def_arn) or (self.task_def and self.task_def_arn): raise ValueError(("Either (but not both) a task_def (dict) ortask_def_arn (string) must be assigned")) if not self.task_def_arn: # Register the task and get assigned taskDefinition ID (arn) response = client.register_task_definition(**self.task_def) self.task_def_arn = response["taskDefinition"]["taskDefinitionArn"] run_task_kwargs = self.run_task_kwargs run_task_kwargs.update( { "taskDefinition": self.task_def_arn, "cluster": self.cluster, "overrides": self.combined_overrides, } ) # Submit the task to AWS ECS and get assigned task ID # (list containing 1 string) response = client.run_task(**run_task_kwargs) if response["failures"]: raise Exception(", ".join(["fail to run task {0} reason: {1}".format(failure["arn"], failure["reason"]) for failure in response["failures"]])) self._task_ids = [task["taskArn"] for task in response["tasks"]] # Wait on task completion _track_tasks(self._task_ids, self.cluster) ================================================ FILE: luigi/contrib/esindex.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, 
Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Support for Elasticsearch (1.0.0 or newer). Provides an :class:`ElasticsearchTarget` and a :class:`CopyToIndex` template task. Modeled after :class:`luigi.contrib.rdbms.CopyToTable`. A minimal example (assuming elasticsearch is running on localhost:9200): .. code-block:: python class ExampleIndex(CopyToIndex): index = 'example' def docs(self): return [{'_id': 1, 'title': 'An example document.'}] if __name__ == '__main__': task = ExampleIndex() luigi.build([task], local_scheduler=True) All options: .. code-block:: python class ExampleIndex(CopyToIndex): host = 'localhost' port = 9200 index = 'example' doc_type = 'default' purge_existing_index = True marker_index_hist_size = 1 def docs(self): return [{'_id': 1, 'title': 'An example document.'}] if __name__ == '__main__': task = ExampleIndex() luigi.build([task], local_scheduler=True) `Host`, `port`, `index`, `doc_type` parameters are standard elasticsearch. `purge_existing_index` will delete the index, whenever an update is required. This is useful, when one deals with "dumps" that represent the whole data, not just updates. `marker_index_hist_size` sets the maximum number of entries in the 'marker' index: * 0 (default) keeps all updates, * 1 to only remember the most recent update to the index. This can be useful, if an index needs to recreated, even though the corresponding indexing task has been run sometime in the past - but a later indexing task might have altered the index in the meantime. 
There are a two luigi `luigi.cfg` configuration options: .. code-block:: ini [elasticsearch] marker-index = update_log marker-doc-type = entry """ # pylint: disable=F0401,E1101,C0103 import abc import datetime import hashlib import itertools import json import logging import luigi logger = logging.getLogger("luigi-interface") try: import elasticsearch if elasticsearch.__version__ < (1, 0, 0): logger.warning("This module works with elasticsearch 1.0.0 or newer only.") from elasticsearch.connection import Urllib3HttpConnection from elasticsearch.helpers import bulk except ImportError: logger.warning("Loading esindex module without elasticsearch installed. Will crash at runtime if esindex functionality is used.") class ElasticsearchTarget(luigi.Target): """Target for a resource in Elasticsearch.""" marker_index = luigi.configuration.get_config().get("elasticsearch", "marker-index", "update_log") marker_doc_type = luigi.configuration.get_config().get("elasticsearch", "marker-doc-type", "entry") def __init__(self, host, port, index, doc_type, update_id, marker_index_hist_size=0, http_auth=None, timeout=10, extra_elasticsearch_args=None): """ :param host: Elasticsearch server host :type host: str :param port: Elasticsearch server port :type port: int :param index: index name :type index: str :param doc_type: doctype name :type doc_type: str :param update_id: an identifier for this data set :type update_id: str :param marker_index_hist_size: list of changes to the index to remember :type marker_index_hist_size: int :param timeout: Elasticsearch connection timeout :type timeout: int :param extra_elasticsearch_args: extra args for Elasticsearch :type Extra: dict """ if extra_elasticsearch_args is None: extra_elasticsearch_args = {} self.host = host self.port = port self.http_auth = http_auth self.index = index self.doc_type = doc_type self.update_id = update_id self.marker_index_hist_size = marker_index_hist_size self.timeout = timeout self.extra_elasticsearch_args = 
extra_elasticsearch_args self.es = elasticsearch.Elasticsearch( connection_class=Urllib3HttpConnection, host=self.host, port=self.port, http_auth=self.http_auth, timeout=self.timeout, **self.extra_elasticsearch_args, ) def marker_index_document_id(self): """ Generate an id for the indicator document. """ params = "%s:%s:%s" % (self.index, self.doc_type, self.update_id) return hashlib.sha1(params.encode("utf-8")).hexdigest() def touch(self): """ Mark this update as complete. The document id would be sufficient but, for documentation, we index the parameters `update_id`, `target_index`, `target_doc_type` and `date` as well. """ self.create_marker_index() self.es.index( index=self.marker_index, doc_type=self.marker_doc_type, id=self.marker_index_document_id(), body={"update_id": self.update_id, "target_index": self.index, "target_doc_type": self.doc_type, "date": datetime.datetime.now()}, ) self.es.indices.flush(index=self.marker_index) self.ensure_hist_size() def exists(self): """ Test, if this task has been run. """ try: self.es.get(index=self.marker_index, doc_type=self.marker_doc_type, id=self.marker_index_document_id()) return True except elasticsearch.NotFoundError: logger.debug("Marker document not found.") except elasticsearch.ElasticsearchException as err: logger.warn(err) return False def create_marker_index(self): """ Create the index that will keep track of the tasks if necessary. """ if not self.es.indices.exists(index=self.marker_index): self.es.indices.create(index=self.marker_index) def ensure_hist_size(self): """ Shrink the history of updates for a `index/doc_type` combination down to `self.marker_index_hist_size`. 
""" if self.marker_index_hist_size == 0: return result = self.es.search( index=self.marker_index, doc_type=self.marker_doc_type, body={"query": {"term": {"target_index": self.index}}}, sort=("date:desc",) ) for i, hit in enumerate(result.get("hits").get("hits"), start=1): if i > self.marker_index_hist_size: marker_document_id = hit.get("_id") self.es.delete(id=marker_document_id, index=self.marker_index, doc_type=self.marker_doc_type) self.es.indices.flush(index=self.marker_index) class CopyToIndex(luigi.Task): """ Template task for inserting a data set into Elasticsearch. Usage: 1. Subclass and override the required `index` attribute. 2. Implement a custom `docs` method, that returns an iterable over the documents. A document can be a JSON string, e.g. from a newline-delimited JSON (ldj) file (default implementation) or some dictionary. Optional attributes: * doc_type (default), * host (localhost), * port (9200), * settings ({'settings': {}}) * mapping (None), * chunk_size (2000), * raise_on_error (True), * purge_existing_index (False), * marker_index_hist_size (0) If settings are defined, they are only applied at index creation time. """ @property def host(self): """ ES hostname. """ return "localhost" @property def port(self): """ ES port. """ return 9200 @property def http_auth(self): """ ES optional http auth information as either ‘:’ separated string or a tuple, e.g. `('user', 'pass')` or `"user:pass"`. """ return None @property @abc.abstractmethod def index(self): """ The target index. May exist or not. """ return None @property def doc_type(self): """ The target doc_type. """ return "default" @property def mapping(self): """ Dictionary with custom mapping or `None`. """ return None @property def settings(self): """ Settings to be used at index creation time. """ return {"settings": {}} @property def chunk_size(self): """ Single API call for this number of docs. """ return 2000 @property def raise_on_error(self): """ Whether to fail fast. 
""" return True @property def purge_existing_index(self): """ Whether to delete the `index` completely before any indexing. """ return False @property def marker_index_hist_size(self): """ Number of event log entries in the marker index. 0: unlimited. """ return 0 @property def timeout(self): """ Timeout. """ return 10 @property def extra_elasticsearch_args(self): """ Extra arguments to pass to the Elasticsearch constructor """ return {} def docs(self): """ Return the documents to be indexed. Beside the user defined fields, the document may contain an `_index`, `_type` and `_id`. """ with self.input().open("r") as fobj: for line in fobj: yield line # everything below will rarely have to be overridden def _docs(self): """ Since `self.docs` may yield documents that do not explicitly contain `_index` or `_type`, add those attributes here, if necessary. """ iterdocs = iter(self.docs()) first = next(iterdocs) needs_parsing = False if isinstance(first, str): needs_parsing = True elif isinstance(first, dict): pass else: raise RuntimeError("Document must be either JSON strings or dict.") for doc in itertools.chain([first], iterdocs): if needs_parsing: doc = json.loads(doc) if "_index" not in doc: doc["_index"] = self.index if "_type" not in doc: doc["_type"] = self.doc_type yield doc def _init_connection(self): return elasticsearch.Elasticsearch( connection_class=Urllib3HttpConnection, host=self.host, port=self.port, http_auth=self.http_auth, timeout=self.timeout, **self.extra_elasticsearch_args, ) def create_index(self): """ Override to provide code for creating the target index. By default it will be created without any special settings or mappings. """ es = self._init_connection() if not es.indices.exists(index=self.index): es.indices.create(index=self.index, body=self.settings) def delete_index(self): """ Delete the index, if it exists. 
""" es = self._init_connection() if es.indices.exists(index=self.index): es.indices.delete(index=self.index) def update_id(self): """ This id will be a unique identifier for this indexing task. """ return self.task_id def output(self): """ Returns a ElasticsearchTarget representing the inserted dataset. Normally you don't override this. """ return ElasticsearchTarget( host=self.host, port=self.port, http_auth=self.http_auth, index=self.index, doc_type=self.doc_type, update_id=self.update_id(), marker_index_hist_size=self.marker_index_hist_size, timeout=self.timeout, extra_elasticsearch_args=self.extra_elasticsearch_args, ) def run(self): """ Run task, namely: * purge existing index, if requested (`purge_existing_index`), * create the index, if missing, * apply mappings, if given, * set refresh interval to -1 (disable) for performance reasons, * bulk index in batches of size `chunk_size` (2000), * set refresh interval to 1s, * refresh Elasticsearch, * create entry in marker index. """ if self.purge_existing_index: self.delete_index() self.create_index() es = self._init_connection() if self.mapping: es.indices.put_mapping(index=self.index, doc_type=self.doc_type, body=self.mapping) es.indices.put_settings({"index": {"refresh_interval": "-1"}}, index=self.index) bulk(es, self._docs(), chunk_size=self.chunk_size, raise_on_error=self.raise_on_error) es.indices.put_settings({"index": {"refresh_interval": "1s"}}, index=self.index) es.indices.refresh() self.output().touch() ================================================ FILE: luigi/contrib/external_daily_snapshot.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2017 Spotify AB. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # from __future__ import annotations import datetime import logging from typing import Any import luigi logger = logging.getLogger("luigi-interface") class ExternalDailySnapshot(luigi.ExternalTask): """ Abstract class containing a helper method to fetch the latest snapshot. Example:: class MyTask(luigi.Task): def requires(self): return PlaylistContent.latest() All tasks subclassing :class:`ExternalDailySnapshot` must have a :class:`luigi.DateParameter` named ``date``. You can also provide additional parameters to the class and also configure lookback size. Example:: ServiceLogs.latest(service="radio", lookback=21) """ date = luigi.DateParameter() __cache: list[Any] = [] @classmethod def latest(cls, *args, **kwargs): """This is cached so that requires() is deterministic.""" date = kwargs.pop("date", datetime.date.today()) lookback = kwargs.pop("lookback", 14) # hashing kwargs deterministically would be hard. 
Let's just lookup by equality key = (cls, args, kwargs, lookback, date) for k, v in ExternalDailySnapshot.__cache: if k == key: return v val = cls.__latest(date, lookback, args, kwargs) ExternalDailySnapshot.__cache.append((key, val)) return val @classmethod def __latest(cls, date, lookback, args, kwargs): assert lookback > 0 t = None for i in range(lookback): d = date - datetime.timedelta(i) t = cls(date=d, *args, **kwargs) if t.complete(): return t logger.debug("Could not find last dump for %s (looked back %d days)", cls.__name__, lookback) return t ================================================ FILE: luigi/contrib/external_program.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2016 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Template tasks for running external programs as luigi tasks. This module is primarily intended for when you need to call a single external program or shell script, and it's enough to specify program arguments and environment variables. If you need to run multiple commands, chain them together or pipe output from one command to the next, you're probably better off using something like `plumbum`_, and wrapping plumbum commands in normal luigi :py:class:`~luigi.task.Task` s. .. 
_plumbum: https://plumbum.readthedocs.io/ """ import logging import os import re import signal import subprocess import sys import tempfile from contextlib import contextmanager from multiprocessing import Process from time import sleep import luigi from luigi.parameter import ParameterVisibility logger = logging.getLogger("luigi-interface") class ExternalProgramTask(luigi.Task): """ Template task for running an external program in a subprocess The program is run using :py:class:`subprocess.Popen`, with ``args`` passed as a list, generated by :py:meth:`program_args` (where the first element should be the executable). See :py:class:`subprocess.Popen` for details. Your must override :py:meth:`program_args` to specify the arguments you want, and you can optionally override :py:meth:`program_environment` if you want to control the environment variables (see :py:class:`ExternalPythonProgramTask` for an example). By default, the output (stdout and stderr) of the run external program is being captured and displayed after the execution has ended. This behaviour can be overridden by passing ``--capture-output False`` """ capture_output = luigi.BoolParameter(default=True, significant=False, positional=False) stream_for_searching_tracking_url = luigi.parameter.ChoiceParameter( var_type=str, choices=["none", "stdout", "stderr"], default="none", significant=False, positional=False, visibility=ParameterVisibility.HIDDEN, description="Stream for searching tracking URL", ) """ Used for defining which stream should be tracked for URL, may be set to 'stdout', 'stderr' or 'none'. Default value is 'none', so URL tracking is not performed. """ tracking_url_pattern = luigi.OptionalParameter( default=None, significant=False, positional=False, visibility=ParameterVisibility.HIDDEN, description="Regex pattern used for searching URL in the logs of the external program", ) """ Regex pattern used for searching URL in the logs of the external program. 
If a log line matches the regex, the first group in the matching is set as the tracking URL for the job in the web UI. Example: 'Job UI is here: (https?://.*)'. Default value is None, so URL tracking is not performed. """ def program_args(self): """ Override this method to map your task parameters to the program arguments :return: list to pass as ``args`` to :py:class:`subprocess.Popen` """ raise NotImplementedError def program_environment(self): """ Override this method to control environment variables for the program :return: dict mapping environment variable names to values """ env = os.environ.copy() return env @property def always_log_stderr(self): """ When True, stderr will be logged even if program execution succeeded Override to False to log stderr only when program execution fails. """ return True def _clean_output_file(self, file_object): file_object.seek(0) return "".join(map(lambda s: s.decode("utf-8"), file_object.readlines())) def build_tracking_url(self, logs_output): """ This method is intended for transforming pattern match in logs to an URL :param logs_output: Found match of `self.tracking_url_pattern` :return: a tracking URL for the task """ return logs_output def run(self): args = list(map(str, self.program_args())) logger.info("Running command: %s", " ".join(args)) env = self.program_environment() kwargs = {"env": env} tmp_stdout, tmp_stderr = None, None if self.capture_output: tmp_stdout, tmp_stderr = tempfile.TemporaryFile(), tempfile.TemporaryFile() kwargs.update({"stdout": tmp_stdout, "stderr": tmp_stderr}) try: if self.stream_for_searching_tracking_url != "none" and self.tracking_url_pattern is not None: with self._proc_with_tracking_url_context(proc_args=args, proc_kwargs=kwargs) as proc: proc.wait() else: proc = subprocess.Popen(args, **kwargs) with ExternalProgramRunContext(proc): proc.wait() success = proc.returncode == 0 if self.capture_output: stdout = self._clean_output_file(tmp_stdout) stderr = self._clean_output_file(tmp_stderr) 
if stdout: logger.info("Program stdout:\n{}".format(stdout)) if stderr: if self.always_log_stderr or not success: logger.info("Program stderr:\n{}".format(stderr)) else: stdout, stderr = None, None if not success: raise ExternalProgramRunError("Program failed with return code={}:".format(proc.returncode), args, env=env, stdout=stdout, stderr=stderr) finally: if self.capture_output: tmp_stderr.close() tmp_stdout.close() @contextmanager def _proc_with_tracking_url_context(self, proc_args, proc_kwargs): time_to_sleep = 0.5 file_to_write = proc_kwargs.get(self.stream_for_searching_tracking_url) proc_kwargs.update({self.stream_for_searching_tracking_url: subprocess.PIPE}) main_proc = subprocess.Popen(proc_args, **proc_kwargs) pipe_to_read = main_proc.stderr if self.stream_for_searching_tracking_url == "stderr" else main_proc.stdout def _track_url_by_pattern(): """ Scans the pipe looking for a passed pattern, if the pattern is found, `set_tracking_url` callback is sent. If tmp_stdout is passed, also appends lines to this file. 
""" pattern = re.compile(self.tracking_url_pattern) for new_line in iter(pipe_to_read.readline, ""): if new_line: if file_to_write: file_to_write.write(new_line) match = re.search(pattern, new_line.decode("utf-8")) if match: self.set_tracking_url(self.build_tracking_url(match.group(1))) else: file_to_write.flush() sleep(time_to_sleep) track_proc = Process(target=_track_url_by_pattern) try: track_proc.start() with ExternalProgramRunContext(main_proc): yield main_proc finally: # need to wait a bit to let the subprocess read the last lines track_proc.join(time_to_sleep * 2) if track_proc.is_alive(): track_proc.terminate() pipe_to_read.close() class ExternalProgramRunContext: def __init__(self, proc): self.proc = proc def __enter__(self): self.__old_signal = signal.getsignal(signal.SIGTERM) signal.signal(signal.SIGTERM, self.kill_job) return self def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is KeyboardInterrupt: self.kill_job() signal.signal(signal.SIGTERM, self.__old_signal) def kill_job(self, captured_signal=None, stack_frame=None): self.proc.kill() if captured_signal is not None: # adding 128 gives the exit code corresponding to a signal sys.exit(128 + captured_signal) class ExternalProgramRunError(RuntimeError): def __init__(self, message, args, env=None, stdout=None, stderr=None): super(ExternalProgramRunError, self).__init__(message, args, env, stdout, stderr) self.message = message self.args = args self.env = env self.out = stdout self.err = stderr def __str__(self): info = self.message info += "\nCOMMAND: {}".format(" ".join(self.args)) info += "\nSTDOUT: {}".format(self.out or "[empty]") info += "\nSTDERR: {}".format(self.err or "[empty]") env_string = None if self.env: env_string = " ".join(["=".join([k, "'{}'".format(v)]) for k, v in self.env.items()]) info += "\nENVIRONMENT: {}".format(env_string or "[empty]") # reset terminal color in case the ENVIRONMENT changes colors info += "\033[m" return info class 
ExternalPythonProgramTask(ExternalProgramTask): """ Template task for running an external Python program in a subprocess Simple extension of :py:class:`ExternalProgramTask`, adding two :py:class:`luigi.parameter.Parameter` s for setting a virtualenv and for extending the ``PYTHONPATH``. """ virtualenv = luigi.OptionalParameter( default=None, positional=False, description="path to the virtualenv directory to use. It should point to " "the directory containing the ``bin/activate`` file used for " "enabling the virtualenv.", ) extra_pythonpath = luigi.OptionalParameter( default=None, positional=False, description="extend the search path for modules by prepending this value to the ``PYTHONPATH`` environment variable." ) def program_environment(self): env = super(ExternalPythonProgramTask, self).program_environment() if self.extra_pythonpath: pythonpath = ":".join([self.extra_pythonpath, env.get("PYTHONPATH", "")]) env.update({"PYTHONPATH": pythonpath}) if self.virtualenv: # Make the same changes to the env that a normal venv/bin/activate script would path = ":".join(["{}/bin".format(self.virtualenv), env.get("PATH", "")]) env.update({"PATH": path, "VIRTUAL_ENV": self.virtualenv}) # remove PYTHONHOME env variable, if it exists env.pop("PYTHONHOME", None) return env ================================================ FILE: luigi/contrib/ftp.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # """ This library is a wrapper of ftplib or pysftp. It is convenient to move data from/to (S)FTP servers. There is an example on how to use it (example/ftp_experiment_outputs.py) You can also find unittest for each class. Be aware that normal ftp does not provide secure communication. """ import datetime import ftplib import io import logging import os import random import tempfile import luigi import luigi.format import luigi.local_target import luigi.target from luigi.format import FileWrapper logger = logging.getLogger("luigi-interface") class RemoteFileSystem(luigi.target.FileSystem): def __init__(self, host, username=None, password=None, port=None, tls=False, timeout=60, sftp=False, pysftp_conn_kwargs=None): self.host = host self.username = username self.password = password self.tls = tls self.timeout = timeout self.sftp = sftp self.pysftp_conn_kwargs = pysftp_conn_kwargs or {} if port is None: if self.sftp: self.port = 22 else: self.port = 21 else: self.port = port def _connect(self): """ Log in to ftp. """ if self.sftp: self._sftp_connect() else: self._ftp_connect() def _sftp_connect(self): try: import pysftp except ImportError: logger.warning("Please install pysftp to use SFTP.") self.conn = pysftp.Connection(self.host, username=self.username, password=self.password, port=self.port, **self.pysftp_conn_kwargs) def _ftp_connect(self): if self.tls: self.conn = ftplib.FTP_TLS() else: self.conn = ftplib.FTP() self.conn.connect(self.host, self.port, timeout=self.timeout) self.conn.login(self.username, self.password) if self.tls: self.conn.prot_p() def _close(self): """ Close ftp connection. """ if self.sftp: self._sftp_close() else: self._ftp_close() def _sftp_close(self): self.conn.close() def _ftp_close(self): self.conn.quit() def exists(self, path, mtime=None): """ Return `True` if file or directory at `path` exist, False otherwise. 
Additional check on modified time when mtime is passed in. Return False if the file's modified time is older mtime. """ self._connect() if self.sftp: exists = self._sftp_exists(path, mtime) else: exists = self._ftp_exists(path, mtime) self._close() return exists def _sftp_exists(self, path, mtime): exists = False if mtime: exists = self.conn.stat(path).st_mtime > mtime elif self.conn.exists(path): exists = True return exists def _ftp_exists(self, path, mtime): dirname, fn = os.path.split(path) files = self.conn.nlst(dirname) exists = False if path in files or fn in files: if mtime: mdtm = self.conn.sendcmd("MDTM " + path) modified = datetime.datetime.strptime(mdtm[4:], "%Y%m%d%H%M%S") exists = modified > mtime else: exists = True return exists def remove(self, path, recursive=True): """ Remove file or directory at location ``path``. :param path: a path within the FileSystem to remove. :type path: str :param recursive: if the path is a directory, recursively remove the directory and all of its descendants. Defaults to ``True``. :type recursive: bool """ self._connect() if self.sftp: self._sftp_remove(path, recursive) else: self._ftp_remove(path, recursive) self._close() def _sftp_remove(self, path, recursive): if self.conn.isfile(path): self.conn.unlink(path) else: if not recursive: raise RuntimeError("Path is not a regular file, and recursive option is not set") directories = [] # walk the tree, and execute call backs when files, # directories and unknown types are encountered # files must be removed first. then directories can be removed # after the files are gone. 
self.conn.walktree(path, self.conn.unlink, directories.append, self.conn.unlink) for directory in reversed(directories): self.conn.rmdir(directory) self.conn.rmdir(path) def _ftp_remove(self, path, recursive): if recursive: self._rm_recursive(self.conn, path) else: try: # try delete file self.conn.delete(path) except ftplib.all_errors: # it is a folder, delete it self.conn.rmd(path) def _rm_recursive(self, ftp, path): """ Recursively delete a directory tree on a remote server. Source: https://gist.github.com/artlogic/2632647 """ wd = ftp.pwd() # check if it is a file first, because some FTP servers don't return # correctly on ftp.nlst(file) try: ftp.cwd(path) except ftplib.all_errors: # this is a file, we will just delete the file ftp.delete(path) return try: names = ftp.nlst() except ftplib.all_errors: # some FTP servers complain when you try and list non-existent paths return for name in names: if os.path.split(name)[1] in (".", ".."): continue try: ftp.cwd(name) # if we can cwd to it, it's a folder ftp.cwd(wd) # don't try a nuke a folder we're in ftp.cwd(path) # then go back to where we were self._rm_recursive(ftp, name) except ftplib.all_errors: ftp.delete(name) try: ftp.cwd(wd) # do not delete the folder that we are in ftp.rmd(path) except ftplib.all_errors as e: print("_rm_recursive: Could not remove {0}: {1}".format(path, e)) def put(self, local_path, path, atomic=True): """ Put file from local filesystem to (s)FTP. 
""" self._connect() if self.sftp: self._sftp_put(local_path, path, atomic) else: self._ftp_put(local_path, path, atomic) self._close() def _sftp_put(self, local_path, path, atomic): normpath = os.path.normpath(path) directory = os.path.dirname(normpath) self.conn.makedirs(directory) if atomic: tmp_path = os.path.join(directory, "luigi-tmp-{:09d}".format(random.randrange(0, 10_000_000_000))) else: tmp_path = normpath self.conn.put(local_path, tmp_path) if atomic: self.conn.rename(tmp_path, normpath) def _ftp_put(self, local_path, path, atomic): normpath = os.path.normpath(path) folder = os.path.dirname(normpath) # create paths if do not exists for subfolder in folder.split(os.sep): if subfolder and subfolder not in self.conn.nlst(): self.conn.mkd(subfolder) self.conn.cwd(subfolder) # go back to ftp root folder self.conn.cwd("/") # random file name if atomic: tmp_path = folder + os.sep + "luigi-tmp-%09d" % random.randrange(0, 10_000_000_000) else: tmp_path = normpath self.conn.storbinary("STOR %s" % tmp_path, open(local_path, "rb")) if atomic: self.conn.rename(tmp_path, normpath) def get(self, path, local_path): """ Download file from (s)FTP to local filesystem. 
""" normpath = os.path.normpath(local_path) folder = os.path.dirname(normpath) if folder and not os.path.exists(folder): os.makedirs(folder) tmp_local_path = local_path + "-luigi-tmp-%09d" % random.randrange(0, 10_000_000_000) # download file self._connect() if self.sftp: self._sftp_get(path, tmp_local_path) else: self._ftp_get(path, tmp_local_path) self._close() os.replace(tmp_local_path, local_path) def _sftp_get(self, path, tmp_local_path): self.conn.get(path, tmp_local_path) def _ftp_get(self, path, tmp_local_path): self.conn.retrbinary("RETR %s" % path, open(tmp_local_path, "wb").write) def listdir(self, path="."): """ Gets an list of the contents of path in (s)FTP """ self._connect() if self.sftp: contents = self._sftp_listdir(path) else: contents = self._ftp_listdir(path) self._close() return contents def _sftp_listdir(self, path): return self.conn.listdir(remotepath=path) def _ftp_listdir(self, path): return self.conn.nlst(path) class AtomicFtpFile(luigi.target.AtomicLocalFile): """ Simple class that writes to a temp file and upload to ftp on close(). Also cleans up the temp file if close is not invoked. """ def __init__(self, fs, path): """ Initializes an AtomicFtpfile instance. :param fs: :param path: :type path: str """ self._fs = fs super(AtomicFtpFile, self).__init__(path) def move_to_final_destination(self): self._fs.put(self.tmp_path, self.path) @property def fs(self): return self._fs class RemoteTarget(luigi.target.FileSystemTarget): """ Target used for reading from remote files. The target is implemented using intermediate files on the local system. On Python2, these files may not be cleaned up. 
""" def __init__( self, path, host, format=None, username=None, password=None, port=None, mtime=None, tls=False, timeout=60, sftp=False, pysftp_conn_kwargs=None ): if format is None: format = luigi.format.get_default_format() self.path = path self.mtime = mtime self.format = format self.tls = tls self.timeout = timeout self.sftp = sftp self._fs = RemoteFileSystem(host, username, password, port, tls, timeout, sftp, pysftp_conn_kwargs) @property def fs(self): return self._fs def open(self, mode): """ Open the FileSystem target. This method returns a file-like object which can either be read from or written to depending on the specified mode. :param mode: the mode `r` opens the FileSystemTarget in read-only mode, whereas `w` will open the FileSystemTarget in write mode. Subclasses can implement additional options. :type mode: str """ if mode == "w": return self.format.pipe_writer(AtomicFtpFile(self._fs, self.path)) elif mode == "r": temppath = "{}-luigi-tmp-{:09d}".format(self.path.lstrip("/"), random.randrange(0, 10_000_000_000)) try: # store reference to the TemporaryDirectory because it will be removed on GC self.__temp_dir = tempfile.TemporaryDirectory(prefix="luigi-contrib-ftp_") except AttributeError: # TemporaryDirectory only available in Python3, use old behaviour in Python2 # this file will not be cleaned up automatically self.__tmp_path = os.path.join(tempfile.gettempdir(), "luigi-contrib-ftp", temppath) else: self.__tmp_path = os.path.join(self.__temp_dir.name, temppath) # download file to local self._fs.get(self.path, self.__tmp_path) return self.format.pipe_reader(FileWrapper(io.BufferedReader(io.FileIO(self.__tmp_path, "r")))) else: raise Exception("mode must be 'r' or 'w' (got: %s)" % mode) def exists(self): return self.fs.exists(self.path, self.mtime) def put(self, local_path, atomic=True): self.fs.put(local_path, self.path, atomic) def get(self, local_path): self.fs.get(self.path, local_path) ================================================ FILE: 
luigi/contrib/gcp.py ================================================ """ Common code for GCP (google cloud services) integration """ import logging logger = logging.getLogger("luigi-interface") try: import google.auth import httplib2 except ImportError: logger.warning( "Loading GCP module without the python packages httplib2, google-auth. \ This *could* crash at runtime if no other credentials are provided." ) def get_authenticate_kwargs(oauth_credentials=None, http_=None): """Returns a dictionary with keyword arguments for use with discovery Prioritizes oauth_credentials or a http client provided by the user If none provided, falls back to default credentials provided by google's command line utilities. If that also fails, tries using httplib2.Http() Used by `gcs.GCSClient` and `bigquery.BigQueryClient` to initiate the API Client """ if oauth_credentials: authenticate_kwargs = {"credentials": oauth_credentials} elif http_: authenticate_kwargs = {"http": http_} else: # neither http_ or credentials provided try: # try default credentials credentials, _ = google.auth.default() authenticate_kwargs = {"credentials": credentials} except google.auth.exceptions.DefaultCredentialsError: # try http using httplib2 authenticate_kwargs = {"http": httplib2.Http()} return authenticate_kwargs ================================================ FILE: luigi/contrib/gcs.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015 Twitter Inc # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""luigi bindings for Google Cloud Storage"""

import io
import logging
import mimetypes
import os
import tempfile
import time
from io import BytesIO
from urllib.parse import urlsplit

from tenacity import after_log, retry, retry_if_exception, retry_if_exception_type, stop_after_attempt, wait_exponential

import luigi.target
from luigi.contrib import gcp
from luigi.format import FileWrapper

logger = logging.getLogger("luigi-interface")

# Retry when following errors happened
# (stays None when googleapiclient/httplib2 cannot be imported)
RETRYABLE_ERRORS = None

try:
    import httplib2
    from googleapiclient import discovery, errors, http
except ImportError:
    logger.warning(
        "Loading GCS module without the python packages googleapiclient & google-auth. \
        This will crash at runtime if GCS functionality is used."
    )
else:
    RETRYABLE_ERRORS = (httplib2.HttpLib2Error, IOError)

# Number of bytes to send/receive in each request.
CHUNKSIZE = 10 * 1024 * 1024

# Mimetype to use if one can't be guessed from the file extension.
DEFAULT_MIMETYPE = "application/octet-stream"

# Time to sleep while waiting for eventual consistency to finish.
EVENTUAL_CONSISTENCY_SLEEP_INTERVAL = 0.1

# Maximum number of sleeps for eventual consistency.
EVENTUAL_CONSISTENCY_MAX_SLEEPS = 300

# Uri for batch requests
GCS_BATCH_URI = "https://storage.googleapis.com/batch/storage/v1"


# Retry configurations. For more details, see https://tenacity.readthedocs.io/en/latest/
def is_error_5xx(err):
    """Return True if ``err`` is a googleapiclient ``HttpError`` with a status >= 500.

    NOTE(review): ``errors`` is only bound when the googleapiclient import above
    succeeded; without it this predicate would raise NameError when called.
    """
    return isinstance(err, errors.HttpError) and err.resp.status >= 500


# Decorator shared by GCS calls: retry on 5xx responses or RETRYABLE_ERRORS,
# exponential backoff clamped to 1..10 seconds, at most 5 attempts, re-raise the
# last exception, and log each failed attempt at WARNING.
# NOTE(review): when the optional imports failed RETRYABLE_ERRORS is None here;
# confirm how tenacity's retry_if_exception_type handles None before relying on it.
gcs_retry = retry(
    retry=(retry_if_exception(is_error_5xx) | retry_if_exception_type(RETRYABLE_ERRORS)),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    stop=stop_after_attempt(5),
    reraise=True,
    after=after_log(logger, logging.WARNING),
)


def _wait_for_consistency(checker):
    """Eventual consistency: wait until GCS reports something is true.

    This is necessary for e.g. create/delete where the operation might return,
    but won't be reflected for a bit.

    Polls ``checker()`` up to EVENTUAL_CONSISTENCY_MAX_SLEEPS times, sleeping
    EVENTUAL_CONSISTENCY_SLEEP_INTERVAL seconds between attempts (about 30s in
    total with the defaults). If it never returns truthy, a warning is logged
    and the function returns normally; it does not raise.
    """
    for _ in range(EVENTUAL_CONSISTENCY_MAX_SLEEPS):
        if checker():
            return

        time.sleep(EVENTUAL_CONSISTENCY_SLEEP_INTERVAL)

    logger.warning("Exceeded wait for eventual GCS consistency - this may be abug in the library or something is terribly wrong.")


class InvalidDeleteException(luigi.target.FileSystemException):
    """FileSystemException subclass for GCS deletes that are not allowed."""

    pass


class GCSClient(luigi.target.FileSystem):
    """An implementation of a FileSystem over Google Cloud Storage.

    There are several ways to use this class. By default it will use the app
    default credentials, as described at https://developers.google.com/identity/protocols/application-default-credentials .
    Alternatively, you may pass an google-auth credentials object. e.g. to use a service account::

      credentials = google.auth.jwt.Credentials.from_service_account_info(
          '012345678912-ThisIsARandomServiceAccountEmail@developer.gserviceaccount.com',
          'These are the contents of the p12 file that came with the service account',
          scope='https://www.googleapis.com/auth/devstorage.read_write')
      client = GCSClient(oauth_credentials=credentails)

    The chunksize parameter specifies how much data to transfer when downloading
    or uploading files.

    .. warning::
      By default this class will use "automated service discovery" which will require
      a connection to the web. The google api client downloads a JSON file to "create" the
      library interface on the fly. If you want a more hermetic build, you can pass the
      contents of this file (currently found at https://www.googleapis.com/discovery/v1/apis/storage/v1/rest )
      as the ``descriptor`` argument.
""" def __init__(self, oauth_credentials=None, descriptor="", http_=None, chunksize=CHUNKSIZE, **discovery_build_kwargs): self.chunksize = chunksize authenticate_kwargs = gcp.get_authenticate_kwargs(oauth_credentials, http_) build_kwargs = authenticate_kwargs.copy() build_kwargs.update(discovery_build_kwargs) if descriptor: self.client = discovery.build_from_document(descriptor, **build_kwargs) else: build_kwargs.setdefault("cache_discovery", False) self.client = discovery.build("storage", "v1", **build_kwargs) def _path_to_bucket_and_key(self, path): (scheme, netloc, path, _, _) = urlsplit(path) assert scheme == "gs" path_without_initial_slash = path[1:] return netloc, path_without_initial_slash def _is_root(self, key): return len(key) == 0 or key == "/" def _add_path_delimiter(self, key): return key if key[-1:] == "/" else key + "/" @gcs_retry def _obj_exists(self, bucket, obj): try: self.client.objects().get(bucket=bucket, object=obj).execute() except errors.HttpError as ex: if ex.resp["status"] == "404": return False raise else: return True def _list_iter(self, bucket, prefix): request = self.client.objects().list(bucket=bucket, prefix=prefix) response = request.execute() while response is not None: for it in response.get("items", []): yield it request = self.client.objects().list_next(request, response) if request is None: break response = request.execute() @gcs_retry def _do_put(self, media, dest_path): bucket, obj = self._path_to_bucket_and_key(dest_path) request = self.client.objects().insert(bucket=bucket, name=obj, media_body=media) if not media.resumable(): return request.execute() response = None while response is None: status, response = request.next_chunk() if status: logger.debug("Upload progress: %.2f%%", 100 * status.progress()) _wait_for_consistency(lambda: self._obj_exists(bucket, obj)) return response def exists(self, path): bucket, obj = self._path_to_bucket_and_key(path) if self._obj_exists(bucket, obj): return True return self.isdir(path) def 
isdir(self, path): bucket, obj = self._path_to_bucket_and_key(path) if self._is_root(obj): try: self.client.buckets().get(bucket=bucket).execute() except errors.HttpError as ex: if ex.resp["status"] == "404": return False raise obj = self._add_path_delimiter(obj) if self._obj_exists(bucket, obj): return True # Any objects with this prefix resp = self.client.objects().list(bucket=bucket, prefix=obj, maxResults=20).execute() lst = next(iter(resp.get("items", [])), None) return bool(lst) def remove(self, path, recursive=True): (bucket, obj) = self._path_to_bucket_and_key(path) if self._is_root(obj): raise InvalidDeleteException("Cannot delete root of bucket at path {}".format(path)) if self._obj_exists(bucket, obj): self.client.objects().delete(bucket=bucket, object=obj).execute() _wait_for_consistency(lambda: not self._obj_exists(bucket, obj)) return True if self.isdir(path): if not recursive: raise InvalidDeleteException("Path {} is a directory. Must use recursive delete".format(path)) req = http.BatchHttpRequest(batch_uri=GCS_BATCH_URI) for it in self._list_iter(bucket, self._add_path_delimiter(obj)): req.add(self.client.objects().delete(bucket=bucket, object=it["name"])) req.execute() _wait_for_consistency(lambda: not self.isdir(path)) return True return False def put(self, filename, dest_path, mimetype=None, chunksize=None): chunksize = chunksize or self.chunksize resumable = os.path.getsize(filename) > 0 mimetype = mimetype or mimetypes.guess_type(dest_path)[0] or DEFAULT_MIMETYPE media = http.MediaFileUpload(filename, mimetype=mimetype, chunksize=chunksize, resumable=resumable) self._do_put(media, dest_path) def _forward_args_to_put(self, kwargs): return self.put(**kwargs) def put_multiple(self, filepaths, remote_directory, mimetype=None, chunksize=None, num_process=1): if isinstance(filepaths, str): raise ValueError("filenames must be a list of strings. 
If you want to put a single file, use the `put(self, filename, ...)` method") put_kwargs_list = [ { "filename": filepath, "dest_path": os.path.join(remote_directory, os.path.basename(filepath)), "mimetype": mimetype, "chunksize": chunksize, } for filepath in filepaths ] if num_process > 1: from contextlib import closing from multiprocessing import Pool with closing(Pool(num_process)) as p: return p.map(self._forward_args_to_put, put_kwargs_list) else: for put_kwargs in put_kwargs_list: self._forward_args_to_put(put_kwargs) def put_string(self, contents, dest_path, mimetype=None): mimetype = mimetype or mimetypes.guess_type(dest_path)[0] or DEFAULT_MIMETYPE assert isinstance(mimetype, str) if not isinstance(contents, bytes): contents = contents.encode("utf-8") media = http.MediaIoBaseUpload(BytesIO(contents), mimetype, resumable=bool(contents)) self._do_put(media, dest_path) def mkdir(self, path, parents=True, raise_if_exists=False): if self.exists(path): if raise_if_exists: raise luigi.target.FileAlreadyExists() elif not self.isdir(path): raise luigi.target.NotADirectory() else: return self.put_string(b"", self._add_path_delimiter(path), mimetype="text/plain") def copy(self, source_path, destination_path): src_bucket, src_obj = self._path_to_bucket_and_key(source_path) dest_bucket, dest_obj = self._path_to_bucket_and_key(destination_path) if self.isdir(source_path): src_prefix = self._add_path_delimiter(src_obj) dest_prefix = self._add_path_delimiter(dest_obj) source_path = self._add_path_delimiter(source_path) copied_objs = [] for obj in self.listdir(source_path): suffix = obj[len(source_path) :] self.client.objects().copy( sourceBucket=src_bucket, sourceObject=src_prefix + suffix, destinationBucket=dest_bucket, destinationObject=dest_prefix + suffix, body={} ).execute() copied_objs.append(dest_prefix + suffix) _wait_for_consistency(lambda: all(self._obj_exists(dest_bucket, obj) for obj in copied_objs)) else: self.client.objects().copy( sourceBucket=src_bucket, 
sourceObject=src_obj, destinationBucket=dest_bucket, destinationObject=dest_obj, body={} ).execute() _wait_for_consistency(lambda: self._obj_exists(dest_bucket, dest_obj)) def rename(self, *args, **kwargs): """ Alias for ``move()`` """ self.move(*args, **kwargs) def move(self, source_path, destination_path): """ Rename/move an object from one GCS location to another. """ self.copy(source_path, destination_path) self.remove(source_path) def listdir(self, path): """ Get an iterable with GCS folder contents. Iterable contains paths relative to queried path. """ bucket, obj = self._path_to_bucket_and_key(path) obj_prefix = self._add_path_delimiter(obj) if self._is_root(obj_prefix): obj_prefix = "" obj_prefix_len = len(obj_prefix) for it in self._list_iter(bucket, obj_prefix): yield self._add_path_delimiter(path) + it["name"][obj_prefix_len:] def list_wildcard(self, wildcard_path): """Yields full object URIs matching the given wildcard. Currently only the '*' wildcard after the last path delimiter is supported. (If we need "full" wildcard functionality we should bring in gsutil dependency with its https://github.com/GoogleCloudPlatform/gsutil/blob/master/gslib/wildcard_iterator.py...) """ path, wildcard_obj = wildcard_path.rsplit("/", 1) assert "*" not in path, "The '*' wildcard character is only supported after the last '/'" wildcard_parts = wildcard_obj.split("*") assert len(wildcard_parts) == 2, "Only one '*' wildcard is supported" for it in self.listdir(path): if ( it.startswith(path + "/" + wildcard_parts[0]) and it.endswith(wildcard_parts[1]) and len(it) >= len(path + "/" + wildcard_parts[0]) + len(wildcard_parts[1]) ): yield it @gcs_retry def download(self, path, chunksize=None, chunk_callback=lambda _: False): """Downloads the object contents to local file system. Optionally stops after the first chunk for which chunk_callback returns True. 
""" chunksize = chunksize or self.chunksize bucket, obj = self._path_to_bucket_and_key(path) with tempfile.NamedTemporaryFile(delete=False) as fp: # We can't return the tempfile reference because of a bug in python: http://bugs.python.org/issue18879 return_fp = _DeleteOnCloseFile(fp.name, "r") # Special case empty files because chunk-based downloading doesn't work. result = self.client.objects().get(bucket=bucket, object=obj).execute() if int(result["size"]) == 0: return return_fp request = self.client.objects().get_media(bucket=bucket, object=obj) downloader = http.MediaIoBaseDownload(fp, request, chunksize=chunksize) done = False while not done: _, done = downloader.next_chunk() if chunk_callback(fp): done = True return return_fp class _DeleteOnCloseFile(io.FileIO): def close(self): super(_DeleteOnCloseFile, self).close() try: os.remove(self.name) except OSError: # Catch a potential threading race condition and also allow this # method to be called multiple times. pass def readable(self): return True def writable(self): return False def seekable(self): return True class AtomicGCSFile(luigi.target.AtomicLocalFile): """ A GCS file that writes to a temp file and put to GCS on close. 
""" def __init__(self, path, gcs_client): self.gcs_client = gcs_client super(AtomicGCSFile, self).__init__(path) def move_to_final_destination(self): self.gcs_client.put(self.tmp_path, self.path) class GCSTarget(luigi.target.FileSystemTarget): fs = None def __init__(self, path, format=None, client=None): super(GCSTarget, self).__init__(path) if format is None: format = luigi.format.get_default_format() self.format = format self.fs = client or GCSClient() def open(self, mode="r"): if mode == "r": return self.format.pipe_reader(FileWrapper(io.BufferedReader(self.fs.download(self.path)))) elif mode == "w": return self.format.pipe_writer(AtomicGCSFile(self.path, self.fs)) else: raise ValueError("Unsupported open mode '{}'".format(mode)) class GCSFlagTarget(GCSTarget): """ Defines a target directory with a flag-file (defaults to `_SUCCESS`) used to signify job success. This checks for two things: * the path exists (just like the GCSTarget) * the _SUCCESS file exists within the directory. Because Hadoop outputs into a directory and not a single file, the path is assumed to be a directory. This is meant to be a handy alternative to AtomicGCSFile. The AtomicFile approach can be burdensome for GCS since there are no directories, per se. If we have 1,000,000 output files, then we have to rename 1,000,000 objects. """ fs = None def __init__(self, path, format=None, client=None, flag="_SUCCESS"): """ Initializes a GCSFlagTarget. :param path: the directory where the files are stored. :type path: str :param client: :type client: :param flag: :type flag: str """ if format is None: format = luigi.format.get_default_format() if path[-1] != "/": raise ValueError("GCSFlagTarget requires the path to be to a directory. 
It must end with a slash ( / ).") super(GCSFlagTarget, self).__init__(path, format=format, client=client) self.format = format self.fs = client or GCSClient() self.flag = flag def exists(self): flag_target = self.path + self.flag return self.fs.exists(flag_target) ================================================ FILE: luigi/contrib/hadoop.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Run Hadoop Mapreduce jobs using Hadoop Streaming. To run a job, you need to subclass :py:class:`luigi.contrib.hadoop.JobTask` and implement a ``mapper`` and ``reducer`` methods. See :doc:`/example_top_artists` for an example of how to run a Hadoop job. 
""" import abc import datetime import glob import hashlib import logging import os import pickle import random import re import shutil import signal import subprocess import sys import tempfile import warnings from io import StringIO from itertools import groupby import luigi import luigi.contrib.gcs import luigi.contrib.hdfs import luigi.contrib.s3 from luigi import configuration from luigi.contrib import mrrunner from luigi.task import Config try: # See benchmark at https://gist.github.com/mvj3/02dca2bcc8b0ef1bbfb5 import ujson as json except ImportError: import json logger = logging.getLogger("luigi-interface") _attached_packages = [] TRACKING_RE = re.compile(r"(tracking url|the url to track the job):\s+(?P.+)$") class hadoop(Config): pool = luigi.OptionalParameter( default=None, description=("Hadoop pool so use for Hadoop tasks. To specify pools per tasks, see BaseHadoopJobTask.pool"), ) def attach(*packages): """ Attach a python package to hadoop map reduce tarballs to make those packages available on the hadoop cluster. """ _attached_packages.extend(packages) def dereference(f): if os.path.islink(f): # by joining with the dirname we are certain to get the absolute path return dereference(os.path.join(os.path.dirname(f), os.readlink(f))) else: return f def get_extra_files(extra_files): result = [] for f in extra_files: if isinstance(f, str): src, dst = f, os.path.basename(f) elif isinstance(f, tuple): src, dst = f else: raise Exception() if os.path.isdir(src): src_prefix = os.path.join(src, "") for base, dirs, files in os.walk(src): for f in files: f_src = os.path.join(base, f) f_src_stripped = f_src[len(src_prefix) :] f_dst = os.path.join(dst, f_src_stripped) result.append((f_src, f_dst)) else: result.append((src, dst)) return result def create_packages_archive(packages, filename): """ Create a tar archive which will contain the files for the packages listed in packages. 
""" import tarfile tar = tarfile.open(filename, "w") def add(src, dst): logger.debug("adding to tar: %s -> %s", src, dst) tar.add(src, dst) def add_files_for_package(sub_package_path, root_package_path, root_package_name): for root, dirs, files in os.walk(sub_package_path): if ".svn" in dirs: dirs.remove(".svn") for f in files: if not f.endswith(".pyc") and not f.startswith("."): add(dereference(root + "/" + f), root.replace(root_package_path, root_package_name) + "/" + f) for package in packages: # Put a submodule's entire package in the archive. This is the # magic that usually packages everything you need without # having to attach packages/modules explicitly if not getattr(package, "__path__", None) and "." in package.__name__: package = __import__(package.__name__.rpartition(".")[0], None, None, "non_empty") n = package.__name__.replace(".", "/") if getattr(package, "__path__", None): # TODO: (BUG) picking only the first path does not # properly deal with namespaced packages in different # directories p = package.__path__[0] if p.endswith(".egg") and os.path.isfile(p): raise "egg files not supported!!!" 
# Add the entire egg file # p = p[:p.find('.egg') + 4] # add(dereference(p), os.path.basename(p)) else: # include __init__ files from parent projects root = [] for parent in package.__name__.split(".")[0:-1]: root.append(parent) module_name = ".".join(root) directory = "/".join(root) add(dereference(__import__(module_name, None, None, "non_empty").__path__[0] + "/__init__.py"), directory + "/__init__.py") add_files_for_package(p, p, n) # include egg-info directories that are parallel: for egg_info_path in glob.glob(p + "*.egg-info"): logger.debug('Adding package metadata to archive for "%s" found at "%s"', package.__name__, egg_info_path) add_files_for_package(egg_info_path, p, n) else: f = package.__file__ if f.endswith("pyc"): f = f[:-3] + "py" if n.find(".") == -1: add(dereference(f), os.path.basename(f)) else: add(dereference(f), n + ".py") tar.close() def flatten(sequence): """ A simple generator which flattens a sequence. Only one level is flattened. .. code-block:: python (1, (2, 3), 4) -> (1, 2, 3, 4) """ for item in sequence: if hasattr(item, "__iter__") and not isinstance(item, str) and not isinstance(item, bytes): for i in item: yield i else: yield item class HadoopRunContext: def __init__(self): self.job_id = None self.application_id = None def __enter__(self): self.__old_signal = signal.getsignal(signal.SIGTERM) signal.signal(signal.SIGTERM, self.kill_job) return self def kill_job(self, captured_signal=None, stack_frame=None): if self.application_id: logger.info("Job interrupted, killing application %s" % self.application_id) subprocess.call(["yarn", "application", "-kill", self.application_id]) elif self.job_id: logger.info("Job interrupted, killing job %s", self.job_id) subprocess.call(["mapred", "job", "-kill", self.job_id]) if captured_signal is not None: # adding 128 gives the exit code corresponding to a signal sys.exit(128 + captured_signal) def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is KeyboardInterrupt: self.kill_job() 
signal.signal(signal.SIGTERM, self.__old_signal) class HadoopJobError(RuntimeError): def __init__(self, message, out=None, err=None): super(HadoopJobError, self).__init__(message, out, err) self.message = message self.out = out self.err = err def __str__(self): return self.message def run_and_track_hadoop_job(arglist, tracking_url_callback=None, env=None): """ Runs the job by invoking the command from the given arglist. Finds tracking urls from the output and attempts to fetch errors using those urls if the job fails. Throws HadoopJobError with information about the error (including stdout and stderr from the process) on failure and returns normally otherwise. :param arglist: :param tracking_url_callback: :param env: :return: """ logger.info("%s", subprocess.list2cmdline(arglist)) def write_luigi_history(arglist, history): """ Writes history to a file in the job's output directory in JSON format. Currently just for tracking the job ID in a configuration where no history is stored in the output directory by Hadoop. """ history_filename = configuration.get_config().get("core", "history-filename", "") if history_filename and "-output" in arglist: output_dir = arglist[arglist.index("-output") + 1] f = luigi.contrib.hdfs.HdfsTarget(os.path.join(output_dir, history_filename)).open("w") f.write(json.dumps(history)) f.close() def track_process(arglist, tracking_url_callback, env=None): # Dump stdout to a temp file, poll stderr and log it temp_stdout = tempfile.TemporaryFile("w+t") proc = subprocess.Popen(arglist, stdout=temp_stdout, stderr=subprocess.PIPE, env=env, close_fds=True, universal_newlines=True) # We parse the output to try to find the tracking URL. # This URL is useful for fetching the logs of the job. 
tracking_url = None job_id = None application_id = None err_lines = [] with HadoopRunContext() as hadoop_context: while proc.poll() is None: err_line = proc.stderr.readline() err_lines.append(err_line) err_line = err_line.strip() if err_line: logger.info("%s", err_line) err_line = err_line.lower() tracking_url_match = TRACKING_RE.search(err_line) if tracking_url_match: tracking_url = tracking_url_match.group("url") try: tracking_url_callback(tracking_url) except Exception as e: logger.error("Error in tracking_url_callback, disabling! %s", e) def tracking_url_callback(x): return None if err_line.find("running job") != -1: # hadoop jar output job_id = err_line.split("running job: ")[-1] if err_line.find("submitted hadoop job:") != -1: # scalding output job_id = err_line.split("submitted hadoop job: ")[-1] if err_line.find("submitted application ") != -1: application_id = err_line.split("submitted application ")[-1] hadoop_context.job_id = job_id hadoop_context.application_id = application_id # Read the rest + stdout err = "".join(err_lines + [an_err_line for an_err_line in proc.stderr]) temp_stdout.seek(0) out = "".join(temp_stdout.readlines()) if proc.returncode == 0: write_luigi_history(arglist, {"job_id": job_id}) return (out, err) # Try to fetch error logs if possible message = "Streaming job failed with exit code %d. 
" % proc.returncode if not tracking_url: raise HadoopJobError(message + "Also, no tracking url found.", out, err) try: task_failures = fetch_task_failures(tracking_url) except Exception as e: raise HadoopJobError(message + "Additionally, an error occurred when fetching data from %s: %s" % (tracking_url, e), out, err) if not task_failures: raise HadoopJobError(message + "Also, could not fetch output from tasks.", out, err) else: raise HadoopJobError(message + "Output from tasks below:\n%s" % task_failures, out, err) if tracking_url_callback is None: def tracking_url_callback(x): return None return track_process(arglist, tracking_url_callback, env) def fetch_task_failures(tracking_url): """ Uses mechanize to fetch the actual task logs from the task tracker. This is highly opportunistic, and we might not succeed. So we set a low timeout and hope it works. If it does not, it's not the end of the world. TODO: Yarn has a REST API that we should probably use instead: http://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/WebServicesIntro.html """ import mechanize timeout = 3.0 failures_url = tracking_url.replace("jobdetails.jsp", "jobfailures.jsp") + "&cause=failed" logger.debug("Fetching data from %s", failures_url) b = mechanize.Browser() b.open(failures_url, timeout=timeout) links = list(b.links(text_regex="Last 4KB")) # For some reason text_regex='All' doesn't work... 
no idea why links = random.sample(links, min(10, len(links))) # Fetch a random subset of all failed tasks, so not to be biased towards the early fails error_text = [] for link in links: task_url = link.url.replace("&start=-4097", "&start=-100000") # Increase the offset logger.debug("Fetching data from %s", task_url) b2 = mechanize.Browser() try: r = b2.open(task_url, timeout=timeout) data = r.read() except Exception as e: logger.debug("Error fetching data from %s: %s", task_url, e) continue # Try to get the hex-encoded traceback back from the output for exc in re.findall(r"luigi-exc-hex=[0-9a-f]+", data): error_text.append("---------- %s:" % task_url) error_text.append(exc.split("=")[-1].decode("hex")) return "\n".join(error_text) class JobRunner: run_job = NotImplemented class HadoopJobRunner(JobRunner): """ Takes care of uploading & executing a Hadoop job using Hadoop streaming. TODO: add code to support Elastic Mapreduce (using boto) and local execution. """ def __init__( self, streaming_jar, modules=None, streaming_args=None, libjars=None, libjars_in_hdfs=None, jobconfs=None, input_format=None, output_format=None, end_job_with_atomic_move_dir=True, archives=None, ): def get(x, default): return x is not None and x or default self.streaming_jar = streaming_jar self.modules = get(modules, []) self.streaming_args = get(streaming_args, []) self.libjars = get(libjars, []) self.libjars_in_hdfs = get(libjars_in_hdfs, []) self.archives = get(archives, []) self.jobconfs = get(jobconfs, {}) self.input_format = input_format self.output_format = output_format self.end_job_with_atomic_move_dir = end_job_with_atomic_move_dir self.tmp_dir = False def run_job(self, job, tracking_url_callback=None): if tracking_url_callback is not None: warnings.warn("tracking_url_callback argument is deprecated, task.set_tracking_url is used instead.", DeprecationWarning) packages = [luigi] + self.modules + job.extra_modules() + list(_attached_packages) # find the module containing the job 
packages.append(__import__(job.__module__, None, None, "dummy")) # find the path to out runner.py runner_path = mrrunner.__file__ # assume source is next to compiled if runner_path.endswith("pyc"): runner_path = runner_path[:-3] + "py" base_tmp_dir = configuration.get_config().get("core", "tmp-dir", None) if base_tmp_dir: warnings.warn( "The core.tmp-dir configuration item is" " deprecated, please use the TMPDIR" " environment variable if you wish" " to control where luigi.contrib.hadoop may" " create temporary files and directories." ) self.tmp_dir = os.path.join(base_tmp_dir, "hadoop_job_%016x" % random.getrandbits(64)) os.makedirs(self.tmp_dir) else: self.tmp_dir = tempfile.mkdtemp() logger.debug("Tmp dir: %s", self.tmp_dir) # build arguments config = configuration.get_config() python_executable = config.get("hadoop", "python-executable", "python") runner_arg = "mrrunner.pex" if job.package_binary is not None else "mrrunner.py" command = "{0} {1} {{step}}".format(python_executable, runner_arg) map_cmd = command.format(step="map") cmb_cmd = command.format(step="combiner") red_cmd = command.format(step="reduce") output_final = job.output().path # atomic output: replace output with a temporary work directory if self.end_job_with_atomic_move_dir: illegal_targets = (luigi.contrib.s3.S3FlagTarget, luigi.contrib.gcs.GCSFlagTarget) if isinstance(job.output(), illegal_targets): raise TypeError("end_job_with_atomic_move_dir is not supported for {}".format(illegal_targets)) output_hadoop = "{output}-temp-{time}".format(output=output_final, time=datetime.datetime.now().isoformat().replace(":", "-")) else: output_hadoop = output_final arglist = luigi.contrib.hdfs.load_hadoop_cmd() + ["jar", self.streaming_jar] # 'libjars' is a generic option, so place it first libjars = [libjar for libjar in self.libjars] for libjar in self.libjars_in_hdfs: run_cmd = luigi.contrib.hdfs.load_hadoop_cmd() + ["fs", "-get", libjar, self.tmp_dir] logger.debug(subprocess.list2cmdline(run_cmd)) 
subprocess.call(run_cmd) libjars.append(os.path.join(self.tmp_dir, os.path.basename(libjar))) if libjars: arglist += ["-libjars", ",".join(libjars)] # 'archives' is also a generic option archives = [] extra_archives = job.extra_archives() if self.archives: archives = self.archives if extra_archives: archives += extra_archives if archives: arglist += ["-archives", ",".join(archives)] # Add static files and directories extra_files = get_extra_files(job.extra_files()) files = [] for src, dst in extra_files: dst_tmp = "%s_%09d" % (dst.replace("/", "_"), random.randint(0, 999999999)) files += ["%s#%s" % (src, dst_tmp)] # -files doesn't support subdirectories, so we need to create the dst_tmp -> dst manually job.add_link(dst_tmp, dst) if files: arglist += ["-files", ",".join(files)] jobconfs = job.jobconfs() for k, v in self.jobconfs.items(): jobconfs.append("%s=%s" % (k, v)) for conf in jobconfs: arglist += ["-D", conf] arglist += self.streaming_args # Add additional non-generic per-job streaming args extra_streaming_args = job.extra_streaming_arguments() for arg, value in extra_streaming_args: if not arg.startswith("-"): # safety first arg = "-" + arg arglist += [arg, value] arglist += ["-mapper", map_cmd] if job.combiner != NotImplemented: arglist += ["-combiner", cmb_cmd] if job.reducer != NotImplemented: arglist += ["-reducer", red_cmd] packages_fn = "mrrunner.pex" if job.package_binary is not None else "packages.tar" files = [ runner_path if job.package_binary is None else None, os.path.join(self.tmp_dir, packages_fn), os.path.join(self.tmp_dir, "job-instance.pickle"), ] for f in filter(None, files): arglist += ["-file", f] if self.output_format: arglist += ["-outputformat", self.output_format] if self.input_format: arglist += ["-inputformat", self.input_format] allowed_input_targets = (luigi.contrib.hdfs.HdfsTarget, luigi.contrib.s3.S3Target, luigi.contrib.gcs.GCSTarget) for target in luigi.task.flatten(job.input_hadoop()): if not isinstance(target, 
allowed_input_targets): raise TypeError("target must one of: {}".format(allowed_input_targets)) arglist += ["-input", target.path] allowed_output_targets = (luigi.contrib.hdfs.HdfsTarget, luigi.contrib.s3.S3FlagTarget, luigi.contrib.gcs.GCSFlagTarget) if not isinstance(job.output(), allowed_output_targets): raise TypeError("output must be one of: {}".format(allowed_output_targets)) arglist += ["-output", output_hadoop] # submit job if job.package_binary is not None: shutil.copy(job.package_binary, os.path.join(self.tmp_dir, "mrrunner.pex")) else: create_packages_archive(packages, os.path.join(self.tmp_dir, "packages.tar")) job.dump(self.tmp_dir) run_and_track_hadoop_job(arglist, tracking_url_callback=job.set_tracking_url) if self.end_job_with_atomic_move_dir: luigi.contrib.hdfs.HdfsTarget(output_hadoop).move_dir(output_final) self.finish() def finish(self): # FIXME: check for isdir? if self.tmp_dir and os.path.exists(self.tmp_dir): logger.debug("Removing directory %s", self.tmp_dir) shutil.rmtree(self.tmp_dir) def __del__(self): self.finish() class DefaultHadoopJobRunner(HadoopJobRunner): """ The default job runner just reads from config and sets stuff. """ def __init__(self): config = configuration.get_config() streaming_jar = config.get("hadoop", "streaming-jar") super(DefaultHadoopJobRunner, self).__init__(streaming_jar=streaming_jar) # TODO: add more configurable options class LocalJobRunner(JobRunner): """ Will run the job locally. This is useful for debugging and also unit testing. Tries to mimic Hadoop Streaming. 
TODO: integrate with JobTask """ def __init__(self, samplelines=None): self.samplelines = samplelines def sample(self, input_stream, n, output): for i, line in enumerate(input_stream): if n is not None and i >= n: break output.write(line) def group(self, input_stream): output = StringIO() lines = [] for i, line in enumerate(input_stream): parts = line.rstrip("\n").split("\t") blob = hashlib.new("md5", str(i).encode("ascii"), usedforsecurity=False).hexdigest() # pseudo-random blob to make sure the input isn't sorted lines.append((parts[:-1], blob, line)) for _, _, line in sorted(lines): output.write(line) output.seek(0) return output def run_job(self, job): map_input = StringIO() for i in luigi.task.flatten(job.input_hadoop()): self.sample(i.open("r"), self.samplelines, map_input) map_input.seek(0) if job.reducer == NotImplemented: # Map only job; no combiner, no reducer map_output = job.output().open("w") job.run_mapper(map_input, map_output) map_output.close() return # run job now... map_output = StringIO() job.run_mapper(map_input, map_output) map_output.seek(0) if job.combiner == NotImplemented: reduce_input = self.group(map_output) else: combine_input = self.group(map_output) combine_output = StringIO() job.run_combiner(combine_input, combine_output) combine_output.seek(0) reduce_input = self.group(combine_output) reduce_output = job.output().open("w") job.run_reducer(reduce_input, reduce_output) reduce_output.close() class BaseHadoopJobTask(luigi.Task): pool = luigi.OptionalParameter(default=None, significant=False, positional=False) # This value can be set to change the default batching increment. Default is 1 for backwards compatibility. 
batch_counter_default = 1 final_mapper = NotImplemented final_combiner = NotImplemented final_reducer = NotImplemented mr_priority = NotImplemented package_binary = None _counter_dict = {} task_id = None def _get_pool(self): """Protected method""" if self.pool: return self.pool if hadoop().pool: return hadoop().pool @abc.abstractmethod def job_runner(self): pass def jobconfs(self): jcs = [] jcs.append("mapred.job.name=%s" % self) if self.mr_priority != NotImplemented: jcs.append("mapred.job.priority=%s" % self.mr_priority()) pool = self._get_pool() if pool is not None: # Supporting two schedulers: fair (default) and capacity using the same option scheduler_type = configuration.get_config().get("hadoop", "scheduler", "fair") if scheduler_type == "fair": jcs.append("mapred.fairscheduler.pool=%s" % pool) elif scheduler_type == "capacity": jcs.append("mapred.job.queue.name=%s" % pool) return jcs def init_local(self): """ Implement any work to setup any internal datastructure etc here. You can add extra input using the requires_local/input_local methods. Anything you set on the object will be pickled and available on the Hadoop nodes. """ pass def init_hadoop(self): pass # available formats are "python" and "json". data_interchange_format = "python" def run(self): # The best solution is to store them as lazy `cached_property`, but it # has extraneous dependency. And `property` is slow (need to be # calculated every time when called), so we save them as attributes # directly. self.serialize = DataInterchange[self.data_interchange_format]["serialize"] self.internal_serialize = DataInterchange[self.data_interchange_format]["internal_serialize"] self.deserialize = DataInterchange[self.data_interchange_format]["deserialize"] self.init_local() self.job_runner().run_job(self) def requires_local(self): """ Default impl - override this method if you need any local input to be accessible in init(). 
""" return [] def requires_hadoop(self): return self.requires() # default impl def input_local(self): return luigi.task.getpaths(self.requires_local()) def input_hadoop(self): return luigi.task.getpaths(self.requires_hadoop()) def deps(self): # Overrides the default implementation return luigi.task.flatten(self.requires_hadoop()) + luigi.task.flatten(self.requires_local()) def on_failure(self, exception): if isinstance(exception, HadoopJobError): return """Hadoop job failed with message: {message} stdout: {stdout} stderr: {stderr} """.format(message=exception.message, stdout=exception.out, stderr=exception.err) else: return super(BaseHadoopJobTask, self).on_failure(exception) DataInterchange = { "python": {"serialize": str, "internal_serialize": repr, "deserialize": eval}, "json": {"serialize": json.dumps, "internal_serialize": json.dumps, "deserialize": json.loads}, } class JobTask(BaseHadoopJobTask): jobconf_truncate = 20000 n_reduce_tasks = 25 reducer = NotImplemented def jobconfs(self): jcs = super(JobTask, self).jobconfs() if self.reducer == NotImplemented: jcs.append("mapred.reduce.tasks=0") else: jcs.append("mapred.reduce.tasks=%s" % self.n_reduce_tasks) if self.jobconf_truncate >= 0: jcs.append("stream.jobconf.truncate.limit=%i" % self.jobconf_truncate) return jcs def init_mapper(self): pass def init_combiner(self): pass def init_reducer(self): pass def _setup_remote(self): self._setup_links() def job_runner(self): # We recommend that you define a subclass, override this method and set up your own config """ Get the MapReduce runner for this job. If all outputs are HdfsTargets, the DefaultHadoopJobRunner will be used. Otherwise, the LocalJobRunner which streams all data through the local machine will be used (great for testing). 
""" outputs = luigi.task.flatten(self.output()) for output in outputs: if not isinstance(output, luigi.contrib.hdfs.HdfsTarget): warnings.warn("Job is using one or more non-HdfsTarget outputs" + " so it will be run in local mode") return LocalJobRunner() else: return DefaultHadoopJobRunner() def reader(self, input_stream): """ Reader is a method which iterates over input lines and outputs records. The default implementation yields one argument containing the line for each line in the input.""" for line in input_stream: yield (line,) def writer(self, outputs, stdout, stderr=sys.stderr): """ Writer format is a method which iterates over the output records from the reducer and formats them for output. The default implementation outputs tab separated items. """ for output in outputs: try: output = flatten(output) if self.data_interchange_format == "json": # Only dump one json string, and skip another one, maybe key or value. output = filter(lambda x: x, output) else: # JSON is already serialized, so we put `self.serialize` in a else statement. output = map(self.serialize, output) print("\t".join(output), file=stdout) except BaseException: print(output, file=stderr) raise def mapper(self, item): """ Re-define to process an input item (usually a line of input data). Defaults to identity mapper that sends all lines to the same reducer. """ yield None, item combiner = NotImplemented def incr_counter(self, *args, **kwargs): """ Increments a Hadoop counter. Since counters can be a bit slow to update, this batches the updates. 
""" threshold = kwargs.get("threshold", self.batch_counter_default) if len(args) == 2: # backwards compatibility with existing hadoop jobs group_name, count = args key = (group_name,) else: group, name, count = args key = (group, name) ct = self._counter_dict.get(key, 0) ct += count if ct >= threshold: new_arg = list(key) + [ct] self._incr_counter(*new_arg) ct = 0 self._counter_dict[key] = ct def _flush_batch_incr_counter(self): """ Increments any unflushed counter values. """ for key, count in self._counter_dict.items(): if count == 0: continue args = list(key) + [count] self._incr_counter(*args) self._counter_dict[key] = 0 def _incr_counter(self, *args): """ Increments a Hadoop counter. Note that this seems to be a bit slow, ~1 ms Don't overuse this function by updating very frequently. """ if len(args) == 2: # backwards compatibility with existing hadoop jobs group_name, count = args print("reporter:counter:%s,%s" % (group_name, count), file=sys.stderr) else: group, name, count = args print("reporter:counter:%s,%s,%s" % (group, name, count), file=sys.stderr) def extra_modules(self): return [] # can be overridden in subclass def extra_files(self): """ Can be overridden in subclass. Each element is either a string, or a pair of two strings (src, dst). * `src` can be a directory (in which case everything will be copied recursively). * `dst` can include subdirectories (foo/bar/baz.txt etc) Uses Hadoop's -files option so that the same file is reused across tasks. """ return [] def extra_streaming_arguments(self): """ Extra arguments to Hadoop command line. Return here a list of (parameter, value) tuples. 
""" return [] def extra_archives(self): """List of paths to archives""" return [] def add_link(self, src, dst): if not hasattr(self, "_links"): self._links = [] self._links.append((src, dst)) def _setup_links(self): if hasattr(self, "_links"): missing = [] for src, dst in self._links: d = os.path.dirname(dst) if d: try: os.makedirs(d) except OSError: pass if not os.path.exists(src): missing.append(src) continue if not os.path.exists(dst): # If the combiner runs, the file might already exist, # so no reason to create the link again os.link(src, dst) if missing: raise HadoopJobError("Missing files for distributed cache: " + ", ".join(missing)) def dump(self, directory=""): """ Dump instance to file. """ with self.no_unpicklable_properties(): file_name = os.path.join(directory, "job-instance.pickle") if self.__module__ == "__main__": d = pickle.dumps(self) module_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0] d = d.replace(b"(c__main__", "(c" + module_name) open(file_name, "wb").write(d) else: pickle.dump(self, open(file_name, "wb")) def _map_input(self, input_stream): """ Iterate over input and call the mapper for each item. If the job has a parser defined, the return values from the parser will be passed as arguments to the mapper. If the input is coded output from a previous run, the arguments will be splitted in key and value. """ for record in self.reader(input_stream): for output in self.mapper(*record): yield output if self.final_mapper != NotImplemented: for output in self.final_mapper(): yield output self._flush_batch_incr_counter() def _reduce_input(self, inputs, reducer, final=NotImplemented): """ Iterate over input, collect values with the same key, and call the reducer for each unique key. 
""" for key, values in groupby(inputs, key=lambda x: self.internal_serialize(x[0])): for output in reducer(self.deserialize(key), (v[1] for v in values)): yield output if final != NotImplemented: for output in final(): yield output self._flush_batch_incr_counter() def run_mapper(self, stdin=sys.stdin, stdout=sys.stdout): """ Run the mapper on the hadoop node. """ self.init_hadoop() self.init_mapper() outputs = self._map_input((line[:-1] for line in stdin)) if self.reducer == NotImplemented: self.writer(outputs, stdout) else: self.internal_writer(outputs, stdout) def run_reducer(self, stdin=sys.stdin, stdout=sys.stdout): """ Run the reducer on the hadoop node. """ self.init_hadoop() self.init_reducer() outputs = self._reduce_input(self.internal_reader((line[:-1] for line in stdin)), self.reducer, self.final_reducer) self.writer(outputs, stdout) def run_combiner(self, stdin=sys.stdin, stdout=sys.stdout): self.init_hadoop() self.init_combiner() outputs = self._reduce_input(self.internal_reader((line[:-1] for line in stdin)), self.combiner, self.final_combiner) self.internal_writer(outputs, stdout) def internal_reader(self, input_stream): """ Reader which uses python eval on each part of a tab separated string. Yields a tuple of python objects. """ for input_line in input_stream: yield list(map(self.deserialize, input_line.split("\t"))) def internal_writer(self, outputs, stdout): """ Writer which outputs the python repr for each item. """ for output in outputs: print("\t".join(map(self.internal_serialize, output)), file=stdout) ================================================ FILE: luigi/contrib/hadoop_jar.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Provides functionality to run a Hadoop job using a Jar (``hadoop jar ...``),
either locally or on a remote host over ssh.
"""

import logging
import os
import random
import shlex
import warnings

import luigi.contrib.hadoop
import luigi.contrib.hdfs

logger = logging.getLogger("luigi-interface")


def fix_paths(job):
    """
    Coerce input arguments to use temporary files when used for output.

    Return a list of temporary file pairs (tmpfile, destination path) and
    a list of arguments.

    Converts each HdfsTarget to a string for the path. An HdfsTarget is treated
    as an output (and redirected to a ``<path>-luigi-tmp-<random>`` target) only
    when it does not exist yet and ``job.atomic_output()`` is true; otherwise
    its path is passed through unchanged.
    """
    tmp_files = []
    args = []
    for x in job.args():
        if isinstance(x, luigi.contrib.hdfs.HdfsTarget):  # input/output
            if x.exists() or not job.atomic_output():  # input
                args.append(x.path)
            else:  # output
                # Strip one trailing slash so the temp suffix is appended to the name, not inside the dir.
                x_path_no_slash = x.path[:-1] if x.path[-1] == "/" else x.path
                y = luigi.contrib.hdfs.HdfsTarget(x_path_no_slash + "-luigi-tmp-%09d" % random.randrange(0, 10_000_000_000))
                tmp_files.append((y, x_path_no_slash))
                logger.info("Using temp path: %s for path %s", y.path, x.path)
                args.append(y.path)
        else:
            try:
                # hopefully the target has a path to use
                args.append(x.path)
            except AttributeError:
                # if there's no path then hope converting it to a string will work
                args.append(str(x))

    return (tmp_files, args)


class HadoopJarJobError(Exception):
    # Raised for configuration problems detected before the jar is launched.
    pass


class HadoopJarJobRunner(luigi.contrib.hadoop.JobRunner):
    """
    JobRunner for `hadoop jar` commands. Used to run a HadoopJarJobTask.
    """

    def __init__(self):
        pass

    def run_job(self, job, tracking_url_callback=None):
        """
        Build and run ``hadoop jar <jar> [main] -D<jobconf>... <args>`` for ``job``.

        When ``job.ssh()`` returns a config dict the command is shell-quoted and
        run through ``ssh`` on the configured host; otherwise the jar must exist
        locally. After the command succeeds, temporary outputs created by
        :func:`fix_paths` are moved to their final paths.

        :param tracking_url_callback: deprecated and ignored apart from a warning;
            ``job.set_tracking_url`` is always used.
        :raises HadoopJarJobError: if no jar is defined, the ssh config is
            incomplete, or the local jar file does not exist.
        """
        if tracking_url_callback is not None:
            warnings.warn("tracking_url_callback argument is deprecated, task.set_tracking_url is used instead.", DeprecationWarning)

        # TODO(jcrobak): libjars, files, etc. Can refactor out of
        # hadoop.HadoopJobRunner
        if not job.jar():
            raise HadoopJarJobError("Jar not defined")

        hadoop_arglist = luigi.contrib.hdfs.load_hadoop_cmd() + ["jar", job.jar()]
        if job.main():
            hadoop_arglist.append(job.main())

        jobconfs = job.jobconfs()

        for jc in jobconfs:
            hadoop_arglist += ["-D" + jc]

        (tmp_files, job_args) = fix_paths(job)

        hadoop_arglist += job_args

        ssh_config = job.ssh()
        if ssh_config:
            host = ssh_config.get("host", None)
            key_file = ssh_config.get("key_file", None)
            username = ssh_config.get("username", None)
            if not host or not key_file or not username:
                raise HadoopJarJobError("missing some config for HadoopRemoteJarJobRunner")
            arglist = ["ssh", "-i", key_file, "-o", "BatchMode=yes"]  # no password prompts etc
            if ssh_config.get("no_host_key_check", False):
                arglist += ["-o", "UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no"]
            arglist.append("{}@{}".format(username, host))
            # The remote side receives a single shell string, so every argument is quoted.
            hadoop_arglist = [shlex.quote(arg) for arg in hadoop_arglist]
            arglist.append(" ".join(hadoop_arglist))
        else:
            if not os.path.exists(job.jar()):
                logger.error("Can't find jar: %s, full path %s", job.jar(), os.path.abspath(job.jar()))
                raise HadoopJarJobError("job jar does not exist")
            arglist = hadoop_arglist

        luigi.contrib.hadoop.run_and_track_hadoop_job(arglist, job.set_tracking_url)

        # Publish the temporary outputs only once the job has finished.
        for a, b in tmp_files:
            a.move(b)


class HadoopJarJobTask(luigi.contrib.hadoop.BaseHadoopJobTask):
    """
    A job task for `hadoop jar` commands that define a jar and (optional) main method.
    """

    def jar(self):
        """
        Path to the jar for this Hadoop Job.
        """
        return None

    def main(self):
        """
        optional main method for this Hadoop Job.
        """
        return None

    def job_runner(self):
        # We recommend that you define a subclass, override this method and set up your own config
        return HadoopJarJobRunner()

    def atomic_output(self):
        """
        If True, then rewrite output arguments to be temp locations and
        atomically move them into place after the job finishes.
        """
        return True

    def ssh(self):
        """
        Set this to run hadoop command remotely via ssh. It needs to be a dict that looks like
        {"host": "myhost", "key_file": None, "username": None, ["no_host_key_check": False]}
        """
        return None

    def args(self):
        """
        Returns an array of args to pass to the job (after ``hadoop jar <jar> [main]``).

        HdfsTarget elements are rewritten by :func:`fix_paths`; other elements
        are passed as their ``path`` attribute or, failing that, ``str()``.
        """
        return []


================================================
FILE: luigi/contrib/hdfs/__init__.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Provides access to HDFS using the :py:class:`HdfsTarget`, a subclass of
:py:class:`~luigi.target.Target`.
You can configure which client is used by setting the "client" config under the
"hdfs" section in the configuration, or by using the ``--hdfs-client`` command
line option. "hadoopcli" is the slowest, but should work out of the box.

Since the hdfs functionality is quite big in luigi, it's split into smaller
files under ``luigi/contrib/hdfs/*.py``. But for the sake of convenience and
API stability, everything is reexported under :py:mod:`luigi.contrib.hdfs`.
"""

# imports
from luigi.contrib.hdfs import clients as hdfs_clients
from luigi.contrib.hdfs import config as hdfs_config
from luigi.contrib.hdfs import error as hdfs_error
from luigi.contrib.hdfs import format as hdfs_format
from luigi.contrib.hdfs import hadoopcli_clients as hdfs_hadoopcli_clients
from luigi.contrib.hdfs import target as hdfs_target
from luigi.contrib.hdfs import webhdfs_client as hdfs_webhdfs_client

# The assignments below re-export the public names of the submodules so that
# callers can keep using ``luigi.contrib.hdfs.<name>`` directly.

# config.py
hdfs = hdfs_config.hdfs
load_hadoop_cmd = hdfs_config.load_hadoop_cmd
get_configured_hadoop_version = hdfs_config.get_configured_hadoop_version
get_configured_hdfs_client = hdfs_config.get_configured_hdfs_client
tmppath = hdfs_config.tmppath


# clients
HDFSCliError = hdfs_error.HDFSCliError
# call_check is a staticmethod of HdfsClient, so this is a plain function.
call_check = hdfs_hadoopcli_clients.HdfsClient.call_check
HdfsClient = hdfs_hadoopcli_clients.HdfsClient
WebHdfsClient = hdfs_webhdfs_client.WebHdfsClient
HdfsClientCdh3 = hdfs_hadoopcli_clients.HdfsClientCdh3
HdfsClientApache1 = hdfs_hadoopcli_clients.HdfsClientApache1
create_hadoopcli_client = hdfs_hadoopcli_clients.create_hadoopcli_client
get_autoconfig_client = hdfs_clients.get_autoconfig_client
# Module-level helpers that forward to the auto-configured client.
exists = hdfs_clients.exists
rename = hdfs_clients.rename
remove = hdfs_clients.remove
mkdir = hdfs_clients.mkdir
listdir = hdfs_clients.listdir


# format.py
HdfsReadPipe = hdfs_format.HdfsReadPipe
HdfsAtomicWritePipe = hdfs_format.HdfsAtomicWritePipe
HdfsAtomicWriteDirPipe = hdfs_format.HdfsAtomicWriteDirPipe
PlainFormat = hdfs_format.PlainFormat
PlainDirFormat = hdfs_format.PlainDirFormat
Plain = hdfs_format.Plain
PlainDir = hdfs_format.PlainDir
CompatibleHdfsFormat = hdfs_format.CompatibleHdfsFormat


# target.py
HdfsTarget = hdfs_target.HdfsTarget
HdfsFlagTarget = hdfs_target.HdfsFlagTarget


================================================
FILE: luigi/contrib/hdfs/abstract_client.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
#
you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Module containing the abstract base class shared by the hdfs clients.
"""

import abc

import luigi.target


class HdfsFileSystem(luigi.target.FileSystem, metaclass=abc.ABCMeta):
    """
    Abstract base class for HDFS file system clients.

    Concrete clients must implement the abstract operations below (remove,
    chmod, chown, count, copy, put, get, mkdir, listdir, touchz) on top of
    the :class:`luigi.target.FileSystem` interface. ``rename`` is provided
    as an alias of ``move``.
    """

    def rename(self, path, dest):
        """
        Rename or move a file.

        In hdfs land, "mv" is often called rename. So we add an alias for
        ``move()`` called ``rename()``. This is also to keep backward
        compatibility since ``move()`` became standardized in luigi's
        filesystem interface.
        """
        return self.move(path, dest)

    def rename_dont_move(self, path, dest):
        """
        Override this method with an implementation that uses rename2,
        which is a rename operation that never moves.

        rename2 -
        https://github.com/apache/hadoop/blob/ae91b13/hadoop-hdfs-project/hadoop-hdfs/src/main/java/org/apache/hadoop/hdfs/protocol/ClientProtocol.java
        (lines 483-523)
        """
        # We only override this method to be able to provide a more specific
        # docstring.
        return super(HdfsFileSystem, self).rename_dont_move(path, dest)

    @abc.abstractmethod
    def remove(self, path, recursive=True, skip_trash=False):
        """Remove ``path``; ``skip_trash`` asks the client to bypass the HDFS trash."""
        pass

    @abc.abstractmethod
    def chmod(self, path, permissions, recursive=False):
        """Change the permissions of ``path``."""
        pass

    @abc.abstractmethod
    def chown(self, path, owner, group, recursive=False):
        """Change the owner and group of ``path``."""
        pass

    @abc.abstractmethod
    def count(self, path):
        """
        Count contents in a directory
        """
        pass

    @abc.abstractmethod
    def copy(self, path, destination):
        """Copy ``path`` to ``destination`` within HDFS."""
        pass

    @abc.abstractmethod
    def put(self, local_path, destination):
        """Upload the local file ``local_path`` to ``destination`` in HDFS."""
        pass

    @abc.abstractmethod
    def get(self, path, local_destination):
        """Download ``path`` from HDFS to ``local_destination``."""
        pass

    @abc.abstractmethod
    def mkdir(self, path, parents=True, raise_if_exists=False):
        """Create the directory ``path``."""
        pass

    @abc.abstractmethod
    def listdir(self, path, ignore_directories=False, ignore_files=False, include_size=False, include_type=False, include_time=False, recursive=False):
        """List the entries under ``path``; the ``include_*`` flags request extra per-entry metadata."""
        pass

    @abc.abstractmethod
    def touchz(self, path):
        """Create an empty file at ``path``."""
        pass


================================================
FILE: luigi/contrib/hdfs/clients.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
The implementations of the hdfs clients.
""" import logging import threading from luigi.contrib.hdfs import config as hdfs_config from luigi.contrib.hdfs import hadoopcli_clients as hdfs_hadoopcli_clients from luigi.contrib.hdfs import webhdfs_client as hdfs_webhdfs_client logger = logging.getLogger("luigi-interface") _AUTOCONFIG_CLIENT = threading.local() def get_autoconfig_client(client_cache=_AUTOCONFIG_CLIENT): """ Creates the client as specified in the `luigi.cfg` configuration. """ try: return client_cache.client except AttributeError: configured_client = hdfs_config.get_configured_hdfs_client() if configured_client == "webhdfs": client_cache.client = hdfs_webhdfs_client.WebHdfsClient() elif configured_client == "hadoopcli": client_cache.client = hdfs_hadoopcli_clients.create_hadoopcli_client() else: raise Exception("Unknown hdfs client " + configured_client) return client_cache.client def _with_ac(method_name): def result(*args, **kwargs): return getattr(get_autoconfig_client(), method_name)(*args, **kwargs) return result exists = _with_ac("exists") rename = _with_ac("rename") remove = _with_ac("remove") mkdir = _with_ac("mkdir") listdir = _with_ac("listdir") ================================================ FILE: luigi/contrib/hdfs/config.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#
"""
You can configure which client is used by setting the "client" config under the
"hdfs" section in the configuration, or by using the ``--hdfs-client`` command
line option. "hadoopcli" is the slowest, but should work out of the box.
"""

import getpass
import os
import random
from urllib.parse import urlparse, urlunparse

import luigi
import luigi.configuration


# Options of the [hdfs] config section; tmp_dir is read from [core] hdfs-tmp-dir.
class hdfs(luigi.Config):
    client_version = luigi.IntParameter(default=None)
    namenode_host = luigi.OptionalParameter(default=None)
    namenode_port = luigi.IntParameter(default=None)
    client = luigi.Parameter(default="hadoopcli")
    tmp_dir = luigi.OptionalParameter(
        default=None,
        config_path=dict(section="core", name="hdfs-tmp-dir"),
    )


# Settings for the command-line client, read from the [hadoop] config section.
class hadoopcli(luigi.Config):
    command = luigi.Parameter(
        default="hadoop",
        config_path=dict(section="hadoop", name="command"),
        description='The hadoop command, will run split() on it, so you can pass something like "hadoop --param"',
    )
    version = luigi.Parameter(default="cdh4", config_path=dict(section="hadoop", name="version"), description="Can also be cdh3 or apache1")


def load_hadoop_cmd():
    """Return the configured hadoop command split on whitespace, as an argv prefix."""
    return hadoopcli().command.split()


def get_configured_hadoop_version():
    """
    CDH4 (hadoop 2+) has a slightly different syntax for interacting with hdfs
    via the command line.

    The default version is CDH4, but one can override
    this setting with "cdh3" or "apache1" in the hadoop section of the config
    in order to use the old syntax.

    The value is returned lower-cased.
    """
    return hadoopcli().version.lower()


def get_configured_hdfs_client():
    """
    This is a helper that fetches the configuration value for 'client' in
    the [hdfs] section. It will return the client that retains backwards
    compatibility when 'client' isn't configured.
    """
    return hdfs().client


def tmppath(path=None, include_unix_username=True):
    """
    Build a random temporary path, optionally derived from ``path``.

    The result is ``<base_dir>/[<unix user>/][<path tail>-]luigitemp-<random number>`` where
    ``base_dir`` is, in order of preference: the configured ``[core] hdfs-tmp-dir``;
    ``/tmp`` on the same scheme and network location as ``path``; or plain ``/tmp``.
    Nothing is created; only the path string is returned.

    @param path: target path for which it is needed to generate temporary location
    @type path: str
    @type include_unix_username: bool
    @rtype: str

    Note that include_unix_username might work on windows too.
    """
    addon = "luigitemp-%09d" % random.randrange(0, 10_000_000_000)
    temp_dir = "/tmp"  # default tmp dir if none is specified in config

    # 1. Figure out to which temporary directory to place
    configured_hdfs_tmp_dir = hdfs().tmp_dir
    if configured_hdfs_tmp_dir is not None:
        # config is superior
        base_dir = configured_hdfs_tmp_dir
    elif path is not None:
        # need to copy correct schema and network location
        parsed = urlparse(path)
        base_dir = urlunparse((parsed.scheme, parsed.netloc, temp_dir, "", "", ""))
    else:
        # just system temporary directory
        base_dir = temp_dir

    # 2. Figure out what to place
    if path is not None:
        if path.startswith(temp_dir + "/"):
            # Not 100%, but some protection from directories like /tmp/tmp/file
            subdir = path[len(temp_dir) :]
        else:
            # Protection from /tmp/hdfs:/dir/file
            parsed = urlparse(path)
            subdir = parsed.path
        subdir = subdir.lstrip("/") + "-"
    else:
        # just return any random temporary location
        subdir = ""

    if include_unix_username:
        subdir = os.path.join(getpass.getuser(), subdir)

    return os.path.join(base_dir, subdir + addon)


================================================
FILE: luigi/contrib/hdfs/error.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Exception types raised by the hdfs clients.
""" class HDFSCliError(Exception): def __init__(self, command, returncode, stdout, stderr): self.returncode = returncode self.stdout = stdout self.stderr = stderr msg = ("Command %r failed [exit code %d]\n---stdout---\n%s\n---stderr---\n%s------------") % (command, returncode, stdout, stderr) super(HDFSCliError, self).__init__(msg) ================================================ FILE: luigi/contrib/hdfs/format.py ================================================ import logging import os import luigi.format from luigi.contrib.hdfs import config as hdfs_config from luigi.contrib.hdfs.clients import exists, listdir, mkdir, remove, rename from luigi.contrib.hdfs.config import load_hadoop_cmd from luigi.contrib.hdfs.error import HDFSCliError logger = logging.getLogger("luigi-interface") class HdfsAtomicWriteError(IOError): pass class HdfsReadPipe(luigi.format.InputPipeProcessWrapper): def __init__(self, path): super(HdfsReadPipe, self).__init__(load_hadoop_cmd() + ["fs", "-cat", path]) class HdfsAtomicWritePipe(luigi.format.OutputPipeProcessWrapper): """ File like object for writing to HDFS The referenced file is first written to a temporary location and then renamed to final location on close(). If close() isn't called the temporary file will be cleaned up when this object is garbage collected TODO: if this is buggy, change it so it first writes to a local temporary file and then uploads it on completion """ def __init__(self, path): self.path = path self.tmppath = hdfs_config.tmppath(self.path) parent_dir = os.path.dirname(self.tmppath) mkdir(parent_dir, parents=True, raise_if_exists=False) super(HdfsAtomicWritePipe, self).__init__(load_hadoop_cmd() + ["fs", "-put", "-", self.tmppath]) def abort(self): logger.info("Aborting %s('%s'). 
Removing temporary file '%s'", self.__class__.__name__, self.path, self.tmppath) super(HdfsAtomicWritePipe, self).abort() remove(self.tmppath, skip_trash=True) def close(self): super(HdfsAtomicWritePipe, self).close() try: if exists(self.path): remove(self.path) except Exception as ex: if isinstance(ex, HDFSCliError) or ex.args[0].contains("FileNotFoundException"): pass else: raise ex if not all(result["result"] for result in rename(self.tmppath, self.path) or []): raise HdfsAtomicWriteError("Atomic write to {} failed".format(self.path)) class HdfsAtomicWriteDirPipe(luigi.format.OutputPipeProcessWrapper): """ Writes a data file to a directory at . """ def __init__(self, path, data_extension=""): self.path = path self.tmppath = hdfs_config.tmppath(self.path) self.datapath = self.tmppath + ("/data%s" % data_extension) super(HdfsAtomicWriteDirPipe, self).__init__(load_hadoop_cmd() + ["fs", "-put", "-", self.datapath]) def abort(self): logger.info("Aborting %s('%s'). Removing temporary dir '%s'", self.__class__.__name__, self.path, self.tmppath) super(HdfsAtomicWriteDirPipe, self).abort() remove(self.tmppath, skip_trash=True) def close(self): super(HdfsAtomicWriteDirPipe, self).close() try: if exists(self.path): remove(self.path) except Exception as ex: if isinstance(ex, HDFSCliError) or ex.args[0].contains("FileNotFoundException"): pass else: raise ex # it's unlikely to fail in this way but better safe than sorry if not all(result["result"] for result in rename(self.tmppath, self.path) or []): raise HdfsAtomicWriteError("Atomic write to {} failed".format(self.path)) if os.path.basename(self.tmppath) in map(os.path.basename, listdir(self.path)): remove(self.path) raise HdfsAtomicWriteError("Atomic write to {} failed".format(self.path)) class PlainFormat(luigi.format.Format): input = "bytes" output = "hdfs" def hdfs_writer(self, path): return self.pipe_writer(path) def hdfs_reader(self, path): return self.pipe_reader(path) def pipe_reader(self, path): return 
HdfsReadPipe(path) def pipe_writer(self, output_pipe): return HdfsAtomicWritePipe(output_pipe) class PlainDirFormat(luigi.format.Format): input = "bytes" output = "hdfs" def hdfs_writer(self, path): return self.pipe_writer(path) def hdfs_reader(self, path): return self.pipe_reader(path) def pipe_reader(self, path): # exclude underscore-prefixedfiles/folders (created by MapReduce) return HdfsReadPipe("%s/[^_]*" % path) def pipe_writer(self, path): return HdfsAtomicWriteDirPipe(path) Plain = PlainFormat() PlainDir = PlainDirFormat() class CompatibleHdfsFormat(luigi.format.Format): output = "hdfs" def __init__(self, writer, reader, input=None): if input is not None: self.input = input self.reader = reader self.writer = writer def pipe_writer(self, output): return self.writer(output) def pipe_reader(self, input): return self.reader(input) def hdfs_writer(self, output): return self.writer(output) def hdfs_reader(self, input): return self.reader(input) # __getstate__/__setstate__ needed for pickling, because self.reader and # self.writer may be unpickleable instance methods of another format class. # This was mainly to support pickling of standard HdfsTarget instances. 
def __getstate__(self): d = self.__dict__.copy() for attr in ("reader", "writer"): method = getattr(self, attr) try: # if instance method, pickle instance and method name d[attr] = method.__self__, method.__func__.__name__ except AttributeError: pass # not an instance method return d def __setstate__(self, d): self.__dict__ = d for attr in ("reader", "writer"): try: method_self, method_name = d[attr] except ValueError: continue method = getattr(method_self, method_name) setattr(self, attr, method) ================================================ FILE: luigi/contrib/hdfs/hadoopcli_clients.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ The implementations of the hdfs clients. """ import datetime import logging import os import re import subprocess import warnings from luigi.contrib.hdfs import abstract_client as hdfs_abstract_client from luigi.contrib.hdfs import config as hdfs_config from luigi.contrib.hdfs import error as hdfs_error from luigi.contrib.hdfs.config import load_hadoop_cmd from luigi.target import FileAlreadyExists logger = logging.getLogger("luigi-interface") def create_hadoopcli_client(): """ Given that we want one of the hadoop cli clients, this one will return the right one. 
""" version = hdfs_config.get_configured_hadoop_version() if version == "cdh4": return HdfsClient() elif version == "cdh3": return HdfsClientCdh3() elif version == "apache1": return HdfsClientApache1() else: raise ValueError("Error: Unknown version specified in Hadoop versionconfiguration parameter") class HdfsClient(hdfs_abstract_client.HdfsFileSystem): """ This client uses Apache 2.x syntax for file system commands, which also matched CDH4. """ recursive_listdir_cmd = ["-ls", "-R"] @staticmethod def call_check(command): p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True, universal_newlines=True) stdout, stderr = p.communicate() if p.returncode != 0: raise hdfs_error.HDFSCliError(command, p.returncode, stdout, stderr) return stdout def exists(self, path): """ Use ``hadoop fs -stat`` to check file existence. """ cmd = load_hadoop_cmd() + ["fs", "-stat", path] logger.debug("Running file existence check: %s", subprocess.list2cmdline(cmd)) p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True, universal_newlines=True) stdout, stderr = p.communicate() if p.returncode == 0: return True else: not_found_pattern = "^.*No such file or directory$" not_found_re = re.compile(not_found_pattern) for line in stderr.split("\n"): if not_found_re.match(line): return False raise hdfs_error.HDFSCliError(cmd, p.returncode, stdout, stderr) def move(self, path, dest): parent_dir = os.path.dirname(dest) if parent_dir != "" and not self.exists(parent_dir): self.mkdir(parent_dir) if not isinstance(path, (list, tuple)): path = [path] else: warnings.warn("Renaming multiple files at once is not atomic.", stacklevel=2) self.call_check(load_hadoop_cmd() + ["fs", "-mv"] + path + [dest]) def remove(self, path, recursive=True, skip_trash=False): if recursive: cmd = load_hadoop_cmd() + ["fs", "-rm", "-r"] else: cmd = load_hadoop_cmd() + ["fs", "-rm"] if skip_trash: cmd = cmd + ["-skipTrash"] cmd = cmd + [path] 
self.call_check(cmd) def chmod(self, path, permissions, recursive=False): if recursive: cmd = load_hadoop_cmd() + ["fs", "-chmod", "-R", permissions, path] else: cmd = load_hadoop_cmd() + ["fs", "-chmod", permissions, path] self.call_check(cmd) def chown(self, path, owner, group, recursive=False): if owner is None: owner = "" if group is None: group = "" ownership = "%s:%s" % (owner, group) if recursive: cmd = load_hadoop_cmd() + ["fs", "-chown", "-R", ownership, path] else: cmd = load_hadoop_cmd() + ["fs", "-chown", ownership, path] self.call_check(cmd) def count(self, path): cmd = load_hadoop_cmd() + ["fs", "-count", path] stdout = self.call_check(cmd) lines = stdout.split("\n") for line in stdout.split("\n"): if line.startswith("OpenJDK 64-Bit Server VM warning") or line.startswith("It's highly recommended") or not line: lines.pop(lines.index(line)) else: (dir_count, file_count, content_size, ppath) = stdout.split() results = {"content_size": content_size, "dir_count": dir_count, "file_count": file_count} return results def copy(self, path, destination): self.call_check(load_hadoop_cmd() + ["fs", "-cp", path, destination]) def put(self, local_path, destination): self.call_check(load_hadoop_cmd() + ["fs", "-put", local_path, destination]) def get(self, path, local_destination): self.call_check(load_hadoop_cmd() + ["fs", "-get", path, local_destination]) def getmerge(self, path, local_destination, new_line=False): if new_line: cmd = load_hadoop_cmd() + ["fs", "-getmerge", "-nl", path, local_destination] else: cmd = load_hadoop_cmd() + ["fs", "-getmerge", path, local_destination] self.call_check(cmd) def mkdir(self, path, parents=True, raise_if_exists=False): if parents and raise_if_exists: raise NotImplementedError("HdfsClient.mkdir can't raise with -p") try: cmd = load_hadoop_cmd() + ["fs", "-mkdir"] + (["-p"] if parents else []) + [path] self.call_check(cmd) except hdfs_error.HDFSCliError as ex: if "File exists" in ex.stderr: if raise_if_exists: raise 
FileAlreadyExists(ex.stderr) else: raise def listdir(self, path, ignore_directories=False, ignore_files=False, include_size=False, include_type=False, include_time=False, recursive=False): """ List ``path`` by parsing ``hadoop fs -ls`` output (``recursive_listdir_cmd`` when recursive; an empty path lists "."). Yields each entry's path, or a tuple of the path followed by size, type character and modification time for the include_* flags that are set. Lines are split on single spaces and the path is taken from the last field, so a path containing spaces is truncated to its last part. """ if not path: path = "." # default to current/home catalog if recursive: cmd = load_hadoop_cmd() + ["fs"] + self.recursive_listdir_cmd + [path] else: cmd = load_hadoop_cmd() + ["fs", "-ls", path] lines = self.call_check(cmd).split("\n") for line in lines: if not line: continue elif line.startswith("OpenJDK 64-Bit Server VM warning") or line.startswith("It's highly recommended") or line.startswith("Found"): continue # "hadoop fs -ls" outputs "Found %d items" as its first line elif ignore_directories and line[0] == "d": continue elif ignore_files and line[0] == "-": continue data = line.split(" ") file = data[-1] size = int(data[-4]) line_type = line[0] extra_data = () if include_size: extra_data += (size,) if include_type: extra_data += (line_type,) if include_time: time_str = "%sT%s" % (data[-3], data[-2]) modification_time = datetime.datetime.strptime(time_str, "%Y-%m-%dT%H:%M") extra_data += (modification_time,) if len(extra_data) > 0: yield (file,) + extra_data else: yield file def touchz(self, path): """ Create an empty file at ``path`` with ``hadoop fs -touchz``. """ self.call_check(load_hadoop_cmd() + ["fs", "-touchz", path]) class HdfsClientCdh3(HdfsClient): """ This client uses CDH3 syntax for file system commands. """ def mkdir(self, path, parents=True, raise_if_exists=False): """ No explicit -p switch, this version of Hadoop always creates parent directories.
""" try: self.call_check(load_hadoop_cmd() + ["fs", "-mkdir", path]) except hdfs_error.HDFSCliError as ex: if "File exists" in ex.stderr: if raise_if_exists: raise FileAlreadyExists(ex.stderr) else: raise def remove(self, path, recursive=True, skip_trash=False): if recursive: cmd = load_hadoop_cmd() + ["fs", "-rmr"] else: cmd = load_hadoop_cmd() + ["fs", "-rm"] if skip_trash: cmd = cmd + ["-skipTrash"] cmd = cmd + [path] self.call_check(cmd) class HdfsClientApache1(HdfsClientCdh3): """ This client uses Apache 1.x syntax for file system commands, which are similar to CDH3 except for the file existence check. """ recursive_listdir_cmd = ["-lsr"] def exists(self, path): cmd = load_hadoop_cmd() + ["fs", "-test", "-e", path] p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True) stdout, stderr = p.communicate() if p.returncode == 0: return True elif p.returncode == 1: return False else: raise hdfs_error.HDFSCliError(cmd, p.returncode, stdout, stderr) ================================================ FILE: luigi/contrib/hdfs/target.py ================================================ # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Provides access to HDFS using the :py:class:`HdfsTarget`, a subclass of :py:class:`~luigi.target.Target`. 
""" import random import warnings from urllib import parse as urlparse import luigi from luigi.contrib.hdfs import clients as hdfs_clients from luigi.contrib.hdfs import format as hdfs_format from luigi.contrib.hdfs.config import tmppath from luigi.target import FileSystemTarget class HdfsTarget(FileSystemTarget): def __init__(self, path=None, format=None, is_tmp=False, fs=None): if path is None: assert is_tmp path = tmppath() super(HdfsTarget, self).__init__(path) if format is None: format = luigi.format.get_default_format() >> hdfs_format.Plain old_format = (hasattr(format, "hdfs_writer") or hasattr(format, "hdfs_reader")) and not hasattr(format, "output") if not old_format and getattr(format, "output", "") != "hdfs": format = format >> hdfs_format.Plain if old_format: warnings.warn( "hdfs_writer and hdfs_reader method for format is deprecated,specify the property output of your format as 'hdfs' instead", DeprecationWarning, stacklevel=2, ) if hasattr(format, "hdfs_writer"): format_writer = format.hdfs_writer else: w_format = format >> hdfs_format.Plain format_writer = w_format.pipe_writer if hasattr(format, "hdfs_reader"): format_reader = format.hdfs_reader else: r_format = format >> hdfs_format.Plain format_reader = r_format.pipe_reader format = hdfs_format.CompatibleHdfsFormat( format_writer, format_reader, ) else: format = hdfs_format.CompatibleHdfsFormat( format.pipe_writer, format.pipe_reader, getattr(format, "input", None), ) self.format = format self.is_tmp = is_tmp (scheme, netloc, path, query, fragment) = urlparse.urlsplit(path) if ":" in path: raise ValueError("colon is not allowed in hdfs filenames") self._fs = fs or hdfs_clients.get_autoconfig_client() def __del__(self): # TODO: not sure is_tmp belongs in Targets construction arguments if self.is_tmp and self.exists(): self.remove(skip_trash=True) @property def fs(self): return self._fs def glob_exists(self, expected_files): ls = list(self.fs.listdir(self.path)) if len(ls) == expected_files: return 
True return False def open(self, mode="r"): """ Open the target through its format: "r" returns the format's pipe reader and "w" its pipe writer; any other mode raises ValueError. """ if mode not in ("r", "w"): raise ValueError("Unsupported open mode '%s'" % mode) if mode == "r": return self.format.pipe_reader(self.path) else: return self.format.pipe_writer(self.path) def remove(self, skip_trash=False): """ Delete the target's path through its file system client. """ self.fs.remove(self.path, skip_trash=skip_trash) def rename(self, path, raise_if_exists=False): """ Does not change self.path. Unlike ``move_dir()``, ``rename()`` might cause nested directories. See spotify/luigi#522 """ if isinstance(path, HdfsTarget): path = path.path if raise_if_exists and self.fs.exists(path): raise RuntimeError("Destination exists: %s" % path) self.fs.rename(self.path, path) def move(self, path, raise_if_exists=False): """ Alias for ``rename()`` """ self.rename(path, raise_if_exists=raise_if_exists) def move_dir(self, path): """ Move using :py:class:`~luigi.contrib.hdfs.abstract_client.HdfsFileSystem.rename_dont_move` New since after luigi v2.1: Does not change self.path One could argue that the implementation should use the mkdir+raise_if_exists approach, but we at Spotify have had more trouble with that over just using plain mv. See spotify/luigi#557 """ self.fs.rename_dont_move(self.path, path) def copy(self, dst_dir): """ Copy to destination directory. """ self.fs.copy(self.path, dst_dir) def is_writable(self): """ Currently only works with hadoopcli. Walks from ``path`` up towards the root and reports whether the deepest existing ancestor accepts a test file written by _is_writable. """ if "/" in self.path: # example path: /log/ap/2013-01-17/00 parts = self.path.split("/") # start with the full path and then up the tree until we can check length = len(parts) for part in range(length): path = "/".join(parts[0 : length - part]) + "/" if self.fs.exists(path): # if the path exists and we can write there, great! if self._is_writable(path): return True # if it exists and we can't =( sad panda else: return False # We went through all parts of the path and we still couldn't find # one that exists.
return False def _is_writable(self, path): test_path = path + ".test_write_access-%09d" % random.randrange(10_000_000_000) try: self.fs.touchz(test_path) self.fs.remove(test_path, recursive=False) return True except hdfs_clients.HDFSCliError: return False class HdfsFlagTarget(HdfsTarget): """ Defines a target directory with a flag-file (defaults to `_SUCCESS`) used to signify job success. This checks for two things: * the path exists (just like the HdfsTarget) * the _SUCCESS file exists within the directory. Because Hadoop outputs into a directory and not a single file, the path is assumed to be a directory. """ def __init__(self, path, format=None, client=None, flag="_SUCCESS"): """ Initializes a HdfsFlagTarget. :param path: the directory where the files are stored. :type path: str :param client: :type client: :param flag: :type flag: str """ if path[-1] != "/": raise ValueError("HdfsFlagTarget requires the path to be to a directory. It must end with a slash ( / ).") super(HdfsFlagTarget, self).__init__(path, format, client) self.flag = flag def exists(self): hadoopSemaphore = self.path + self.flag return self.fs.exists(hadoopSemaphore) ================================================ FILE: luigi/contrib/hdfs/webhdfs_client.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015 VNG Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ A luigi file system client that wraps around the hdfs-library (a webhdfs client) Note. 
This wrapper client is not feature complete yet. As with most software the authors only implement the features they need. If you need to wrap more of the file system operations, please do and contribute back. """ import logging import os import warnings import luigi.contrib.target from luigi.contrib.hdfs import abstract_client as hdfs_abstract_client from luigi.contrib.hdfs import config as hdfs_config logger = logging.getLogger("luigi-interface") class webhdfs(luigi.Config): port = luigi.IntParameter(default=50070, description="Port for webhdfs") user = luigi.Parameter(default="", description="Defaults to $USER envvar", config_path=dict(section="hdfs", name="user")) client_type = luigi.ChoiceParameter(var_type=str, choices=["insecure", "kerberos"], default="insecure", description="Type of hdfs client to use.") class WebHdfsClient(hdfs_abstract_client.HdfsFileSystem): """ A webhdfs that tries to confirm to luigis interface for file existence. The library is using `this api `__. """ def __init__(self, host=None, port=None, user=None, client_type=None): self.host = host or hdfs_config.hdfs().namenode_host self.port = port or webhdfs().port self.user = user or webhdfs().user or os.environ["USER"] self.client_type = client_type or webhdfs().client_type @property def url(self): # the hdfs package allows it to specify multiple namenodes by passing a string containing # multiple namenodes separated by ';' hosts = self.host.split(";") urls = ["http://" + host + ":" + str(self.port) for host in hosts] return ";".join(urls) @property def client(self): # A naive benchmark showed that 1000 existence checks took 2.5 secs # when not recreating the client, and 4.0 secs when recreating it. So # not urgent to memoize it. Note that it *might* be issues with process # forking and whatnot (as the one in the snakebite client) if we # memoize it too trivially. 
if self.client_type == "kerberos": from hdfs.ext.kerberos import KerberosClient return KerberosClient(url=self.url) else: import hdfs return hdfs.InsecureClient(url=self.url, user=self.user) def walk(self, path, depth=1): """ Delegate to the hdfs client's walk(path, depth=depth). """ return self.client.walk(path, depth=depth) def exists(self, path): """ Returns true if the path exists and false otherwise. A HdfsError whose message starts with "File does not exist: " means False; any other HdfsError is re-raised. """ import hdfs try: self.client.status(path) return True except hdfs.util.HdfsError as e: if str(e).startswith("File does not exist: "): return False else: raise e def upload(self, hdfs_path, local_path, overwrite=False): """ Upload the local ``local_path`` to ``hdfs_path``. """ return self.client.upload(hdfs_path, local_path, overwrite=overwrite) def download(self, hdfs_path, local_path, overwrite=False, n_threads=-1): """ Download ``hdfs_path`` to the local ``local_path``. """ return self.client.download(hdfs_path, local_path, overwrite=overwrite, n_threads=n_threads) def remove(self, hdfs_path, recursive=True, skip_trash=False): """ Delete ``hdfs_path``. The underlying delete is called without any trash option, so callers must pass skip_trash=True explicitly (checked with an assert). """ assert skip_trash # Yes, you need to explicitly say skip_trash=True return self.client.delete(hdfs_path, recursive=recursive) def read(self, hdfs_path, offset=0, length=None, buffer_size=None, chunk_size=1024, buffer_char=None): return self.client.read(hdfs_path, offset=offset, length=length, buffer_size=buffer_size, chunk_size=chunk_size, buffer_char=buffer_char) def move(self, path, dest): """ Rename ``path`` to ``dest``, first creating the parent directory of ``dest`` if it does not exist. """ parts = dest.rstrip("/").split("/") if len(parts) > 1: dir_path = "/".join(parts[0:-1]) if not self.exists(dir_path): self.mkdir(dir_path, parents=True) self.client.rename(path, dest) def mkdir(self, path, parents=True, mode=0o755, raise_if_exists=False): """ Has no return value (just like WebHDFS). parents=False and raise_if_exists=True are not implemented and only trigger a warning. """ if not parents or raise_if_exists: warnings.warn("webhdfs mkdir: parents/raise_if_exists not implemented") permission = int(oct(mode)[2:]) # Reuse the octal digits of mode as a decimal-looking int (0o755 -> 755) for makedirs' permission argument self.client.makedirs(path, permission=permission) def chmod(self, path, permissions, recursive=False): """ Raise a NotImplementedError exception.
""" raise NotImplementedError("Webhdfs in luigi doesn't implement chmod") def chown(self, path, owner, group, recursive=False): """ Raise a NotImplementedError exception. """ raise NotImplementedError("Webhdfs in luigi doesn't implement chown") def count(self, path): """ Raise a NotImplementedError exception. """ raise NotImplementedError("Webhdfs in luigi doesn't implement count") def copy(self, path, destination): """ Raise a NotImplementedError exception. """ raise NotImplementedError("Webhdfs in luigi doesn't implement copy") def put(self, local_path, destination): """ Restricted version of upload """ self.upload(local_path, destination) def get(self, path, local_destination): """ Restricted version of download """ self.download(path, local_destination) def listdir(self, path, ignore_directories=False, ignore_files=False, include_size=False, include_type=False, include_time=False, recursive=False): assert not recursive return self.client.list(path, status=False) def touchz(self, path): """ To touchz using the web hdfs "write" cmd. """ self.client.write(path, data="", overwrite=False) ================================================ FILE: luigi/contrib/hive.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import abc import collections import logging import operator import os import re import subprocess import tempfile import warnings import luigi import luigi.contrib.hadoop from luigi.contrib.hdfs import get_autoconfig_client from luigi.target import FileAlreadyExists, FileSystemTarget from luigi.task import flatten logger = logging.getLogger("luigi-interface") class HiveCommandError(RuntimeError): """ Raised when a hive command fails; ``out`` and ``err`` hold whatever stdout and stderr the caller captured. """ def __init__(self, message, out=None, err=None): super(HiveCommandError, self).__init__(message, out, err) self.message = message self.out = out self.err = err def load_hive_cmd(): """ Return the [hive] ``command`` option (default "hive") as an argument list. It is split on runs of whitespace, so repeated or trailing spaces do not produce empty arguments. """ return luigi.configuration.get_config().get("hive", "command", "hive").split() def get_hive_syntax(): return luigi.configuration.get_config().get("hive", "release", "cdh4") def get_hive_warehouse_location(): return luigi.configuration.get_config().get("hive", "warehouse_location", "/user/hive/warehouse") def get_ignored_file_masks(): return luigi.configuration.get_config().get("hive", "ignored_file_masks", None) def run_hive(args, check_return_code=True): """ Runs the `hive` from the command line, passing in the given args, and returns stdout decoded as UTF-8. With the apache release of Hive, some of the table existence checks (which are done using DESCRIBE) do not exit with a return code of 0, so we need an option to ignore the return code and just return stdout for parsing. Raises HiveCommandError on a non-zero exit when check_return_code is set. """ cmd = load_hive_cmd() + args p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, stderr = p.communicate() if check_return_code and p.returncode != 0: raise HiveCommandError("Hive command: {0} failed with error code: {1}".format(" ".join(cmd), p.returncode), stdout, stderr) return stdout.decode("utf-8") def run_hive_cmd(hivecmd, check_return_code=True): """ Runs the given hive query and returns stdout. """ return run_hive(["-e", hivecmd], check_return_code) def run_hive_script(script): """ Runs the contents of the given script in hive and returns stdout.
""" if not os.path.isfile(script): raise RuntimeError("Hive script: {0} does not exist.".format(script)) return run_hive(["-f", script]) def _is_ordered_dict(dikt): return isinstance(dikt, (collections.OrderedDict, dict)) def _validate_partition(partition): """ If partition is set and its size is more than one and not ordered, then we're unable to restore its path in the warehouse """ if partition and len(partition) > 1 and not _is_ordered_dict(partition): raise ValueError("Unable to restore table/partition location") class HiveClient(metaclass=abc.ABCMeta): @abc.abstractmethod def table_location(self, table, database="default", partition=None): """ Returns location of db.table (or db.table.partition). partition is a dict of partition key to value. """ pass @abc.abstractmethod def table_schema(self, table, database="default"): """ Returns list of [(name, type)] for each column in database.table. """ pass @abc.abstractmethod def table_exists(self, table, database="default", partition=None): """ Returns true if db.table (or db.table.partition) exists. partition is a dict of partition key to value. """ pass @abc.abstractmethod def partition_spec(self, partition): """Turn a dict into a string partition specification""" pass class HiveCommandClient(HiveClient): """ Uses `hive` invocations to find information. 
""" def table_location(self, table, database="default", partition=None): cmd = "use {0}; describe formatted {1}".format(database, table) if partition is not None: cmd += " PARTITION ({0})".format(self.partition_spec(partition)) stdout = run_hive_cmd(cmd) for line in stdout.split("\n"): if "Location:" in line: return line.split("\t")[1] def table_exists(self, table, database="default", partition=None): if partition is None: stdout = run_hive_cmd('use {0}; show tables like "{1}";'.format(database, table)) return stdout and table.lower() in stdout else: stdout = run_hive_cmd( """use %s; show partitions %s partition (%s)""" % (database, table, self.partition_spec(partition)) ) if stdout: return True else: return False def table_schema(self, table, database="default"): describe = run_hive_cmd("use {0}; describe {1}".format(database, table)) if not describe or "does not exist" in describe: return None return [tuple([x.strip() for x in line.strip().split("\t")]) for line in describe.strip().split("\n")] def partition_spec(self, partition): """ Turns a dict into the a Hive partition specification string. """ return ",".join(["`{0}`='{1}'".format(k, v) for (k, v) in sorted(partition.items(), key=operator.itemgetter(0))]) class ApacheHiveCommandClient(HiveCommandClient): """ A subclass for the HiveCommandClient to (in some cases) ignore the return code from the hive command so that we can just parse the output. 
""" def table_schema(self, table, database="default"): describe = run_hive_cmd("use {0}; describe {1}".format(database, table), False) if not describe or "Table not found" in describe: return None return [tuple([x.strip() for x in line.strip().split("\t")]) for line in describe.strip().split("\n")] class MetastoreClient(HiveClient): def table_location(self, table, database="default", partition=None): with HiveThriftContext() as client: if partition is not None: try: import hive_metastore.ttypes partition_str = self.partition_spec(partition) thrift_table = client.get_partition_by_name(database, table, partition_str) except hive_metastore.ttypes.NoSuchObjectException: return "" else: thrift_table = client.get_table(database, table) return thrift_table.sd.location def table_exists(self, table, database="default", partition=None): with HiveThriftContext() as client: if partition is None: return table in client.get_all_tables(database) else: return partition in self._existing_partitions(table, database, client) def _existing_partitions(self, table, database, client): def _parse_partition_string(partition_string): partition_def = {} for part in partition_string.split("/"): name, value = part.split("=") partition_def[name] = value return partition_def # -1 is max_parts, the # of partition names to return (-1 = unlimited) partition_strings = client.get_partition_names(database, table, -1) return [_parse_partition_string(existing_partition) for existing_partition in partition_strings] def table_schema(self, table, database="default"): with HiveThriftContext() as client: return [(field_schema.name, field_schema.type) for field_schema in client.get_schema(database, table)] def partition_spec(self, partition): return "/".join("%s=%s" % (k, v) for (k, v) in sorted(partition.items(), key=operator.itemgetter(0))) class HiveThriftContext: """ Context manager for hive metastore client. """ def __enter__(self): try: # Note that this will only work with a CDH release. 
# This uses the thrift bindings generated by the ThriftHiveMetastore service in Beeswax. # If using the Apache release of Hive this import will fail. from hive_metastore import ThriftHiveMetastore from thrift.protocol import TBinaryProtocol from thrift.transport import TSocket, TTransport config = luigi.configuration.get_config() host = config.get("hive", "metastore_host") port = config.getint("hive", "metastore_port") transport = TSocket.TSocket(host, port) transport = TTransport.TBufferedTransport(transport) protocol = TBinaryProtocol.TBinaryProtocol(transport) transport.open() self.transport = transport return ThriftHiveMetastore.Client(protocol) except ImportError as e: raise Exception("Could not import Hive thrift library:" + str(e)) from e def __exit__(self, exc_type, exc_val, exc_tb): self.transport.close() class WarehouseHiveClient(HiveClient): """ Client for managed tables that makes decision based on presence of directory in hdfs """ def __init__(self, hdfs_client=None, warehouse_location=None): self.hdfs_client = hdfs_client or get_autoconfig_client() self.warehouse_location = warehouse_location or get_hive_warehouse_location() def table_schema(self, table, database="default"): """ Not supported by this client; returns NotImplemented instead of a schema. """ return NotImplemented def table_location(self, table, database="default", partition=None): """ Return warehouse_location/database.db/table followed by the partition path. """ return os.path.join(self.warehouse_location, database + ".db", table, self.partition_spec(partition)) def table_exists(self, table, database="default", partition=None): """ The table/partition is considered existing if corresponding path in hdfs exists and contains file except those which match pattern set in `ignored_file_masks` """ path = self.table_location(table, database, partition) if not self.hdfs_client.exists(path): return False ignored_files = get_ignored_file_masks() if ignored_files is None: return True pattern = re.compile(ignored_files) return any(not pattern.match(filename) for filename in self.hdfs_client.listdir(path)) def partition_spec(self, partition):
_validate_partition(partition) return "/".join(["{}={}".format(k, v) for (k, v) in (partition or {}).items()]) def get_default_client(): """ Return the HiveClient for the [hive] ``release`` option: "apache", "metastore" and "warehouse" select the matching client; any other value (default "cdh4") uses HiveCommandClient. """ syntax = get_hive_syntax() if syntax == "apache": return ApacheHiveCommandClient() elif syntax == "metastore": return MetastoreClient() elif syntax == "warehouse": return WarehouseHiveClient() else: return HiveCommandClient() client = get_default_client() class HiveQueryTask(luigi.contrib.hadoop.BaseHadoopJobTask): """ Task to run a hive query. """ # by default, we let hive figure these out. n_reduce_tasks = None bytes_per_reducer = None reducers_max = None @abc.abstractmethod def query(self): """Text of query to run in hive""" raise RuntimeError("Must implement query!") def hiverc(self): """ Location of an rc file to run before the query if hiverc-location key is specified in luigi.cfg, will default to the value there otherwise returns None. Returning a list of rc files will load all of them in order. """ return luigi.configuration.get_config().get("hive", "hiverc-location", default=None) def hivevars(self): """ Returns a dict of key=value settings to be passed along to the hive command line via --hivevar. This option can be used as a separated namespace for script local variables. See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+VariableSubstitution """ return {} def hiveconfs(self): """ Returns a dict of key=value settings to be passed along to the hive command line via --hiveconf.
By default, sets mapred.job.name to task_id and if not None, sets: * mapred.reduce.tasks (n_reduce_tasks) * mapred.fairscheduler.pool (pool) or mapred.job.queue.name (pool) * hive.exec.reducers.bytes.per.reducer (bytes_per_reducer) * hive.exec.reducers.max (reducers_max) """ jcs = {} jcs["mapred.job.name"] = "'" + self.task_id + "'" if self.n_reduce_tasks is not None: jcs["mapred.reduce.tasks"] = self.n_reduce_tasks if self.pool is not None: # Supporting two schedulers: fair (default) and capacity using the same option scheduler_type = luigi.configuration.get_config().get("hadoop", "scheduler", "fair") if scheduler_type == "fair": jcs["mapred.fairscheduler.pool"] = self.pool elif scheduler_type == "capacity": jcs["mapred.job.queue.name"] = self.pool if self.bytes_per_reducer is not None: jcs["hive.exec.reducers.bytes.per.reducer"] = self.bytes_per_reducer if self.reducers_max is not None: jcs["hive.exec.reducers.max"] = self.reducers_max return jcs def job_runner(self): """ Return the HiveQueryRunner that executes this task. """ return HiveQueryRunner() class HiveQueryRunner(luigi.contrib.hadoop.JobRunner): """ Runs a HiveQueryTask by shelling out to hive. """ def prepare_outputs(self, job): """ Called before job is started.
If output is a `FileSystemTarget`, create parent directories so the hive command won't fail """ outputs = flatten(job.output()) for o in outputs: if isinstance(o, FileSystemTarget): parent_dir = os.path.dirname(o.path) if parent_dir and not o.fs.exists(parent_dir): logger.info("Creating parent directory %r", parent_dir) try: # there is a possible race condition # which needs to be handled here o.fs.mkdir(parent_dir) except FileAlreadyExists: pass def get_arglist(self, f_name, job): """ Build the hive command line: ``-f f_name``, one ``-i`` per hiverc file, then ``--hiveconf`` and ``--hivevar`` key=value pairs taken from the job. """ arglist = load_hive_cmd() + ["-f", f_name] hiverc = job.hiverc() if hiverc: if isinstance(hiverc, str): hiverc = [hiverc] for rcfile in hiverc: arglist += ["-i", rcfile] hiveconfs = job.hiveconfs() if hiveconfs: for k, v in hiveconfs.items(): arglist += ["--hiveconf", "{0}={1}".format(k, v)] hivevars = job.hivevars() if hivevars: for k, v in hivevars.items(): arglist += ["--hivevar", "{0}={1}".format(k, v)] logger.info(arglist) return arglist def run_job(self, job, tracking_url_callback=None): """ Write the job's query to a temporary file (UTF-8 encoded when it is a str), run hive on that file and track the resulting hadoop job. """ if tracking_url_callback is not None: warnings.warn("tracking_url_callback argument is deprecated, task.set_tracking_url is used instead.", DeprecationWarning) self.prepare_outputs(job) with tempfile.NamedTemporaryFile() as f: query = job.query() if isinstance(query, str): query = query.encode("utf8") f.write(query) f.flush() arglist = self.get_arglist(f.name, job) return luigi.contrib.hadoop.run_and_track_hadoop_job(arglist, job.set_tracking_url) class HivePartitionTarget(luigi.Target): """ Target representing Hive table or Hive partition """ def __init__(self, table, partition, database="default", fail_missing_table=True, client=None): """ @param table: Table name @type table: str @param partition: partition specification in form of dict of {"partition_column_1": "partition_value_1", "partition_column_2": "partition_value_2", ...
} If `partition` is `None` or `{}` then target is Hive nonpartitioned table @param database: Database name @param fail_missing_table: flag to ignore errors raised due to table nonexistence @param client: `HiveCommandClient` instance. Default if `client is None` """ self.database = database self.table = table self.partition = partition self.client = client or get_default_client() self.fail_missing_table = fail_missing_table def __str__(self): return self.path def exists(self): """ returns `True` if the partition/table exists """ try: logger.debug("Checking Hive table '{d}.{t}' for partition {p}".format(d=self.database, t=self.table, p=str(self.partition or {}))) return self.client.table_exists(self.table, self.database, self.partition) except HiveCommandError: if self.fail_missing_table: raise else: if self.client.table_exists(self.table, self.database): # a real error occurred raise else: # oh the table just doesn't exist return False @property def path(self): """ Returns the path for this HiveTablePartitionTarget's data. """ location = self.client.table_location(self.table, self.database, self.partition) if not location: raise Exception("Couldn't find location for table: {0}".format(str(self))) return location class HiveTableTarget(HivePartitionTarget): """ Target representing non-partitioned table """ def __init__(self, table, database="default", client=None): super(HiveTableTarget, self).__init__( table=table, partition=None, database=database, fail_missing_table=False, client=client, ) class ExternalHiveTask(luigi.ExternalTask): """ External task that depends on a Hive table/partition. """ database = luigi.Parameter(default="default") table = luigi.Parameter() partition: luigi.DictParameter = luigi.DictParameter( default={}, description='Python dictionary specifying the target partition e.g. 
{"date": "2013-01-25"}' ) def output(self): return HivePartitionTarget( database=self.database, table=self.table, partition=self.partition, ) ================================================ FILE: luigi/contrib/kubernetes.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015 Outlier Bio, LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Kubernetes Job wrapper for Luigi. From the Kubernetes website: Kubernetes is an open-source system for automating deployment, scaling, and management of containerized applications. For more information about Kubernetes Jobs: http://kubernetes.io/docs/user-guide/jobs/ Requires: - pykube: ``pip install pykube-ng`` Written and maintained by Marco Capuccini (@mcapuccini). """ import logging import time import uuid from datetime import datetime import luigi logger = logging.getLogger("luigi-interface") try: from pykube.config import KubeConfig from pykube.http import HTTPClient from pykube.objects import Job, Pod except ImportError: logger.warning("pykube is not installed. 
KubernetesJobTask requires pykube.") class kubernetes(luigi.Config): auth_method = luigi.Parameter(default="kubeconfig", description="Authorization method to access the cluster") kubeconfig_path = luigi.Parameter(default="~/.kube/config", description="Path to kubeconfig file for cluster authentication") max_retrials = luigi.IntParameter(default=0, description="Max retrials in event of job failure") kubernetes_namespace = luigi.OptionalParameter(default=None, description="K8s namespace in which the job will run") class KubernetesJobTask(luigi.Task): __DEFAULT_POLL_INTERVAL = 5 # see __track_job __DEFAULT_POD_CREATION_INTERVAL = 5 _kubernetes_config = None # Needs to be loaded at runtime def _init_kubernetes(self): self.__logger = logger self.__logger.debug("Kubernetes auth method: " + self.auth_method) if self.auth_method == "kubeconfig": self.__kube_api = HTTPClient(KubeConfig.from_file(self.kubeconfig_path)) elif self.auth_method == "service-account": self.__kube_api = HTTPClient(KubeConfig.from_service_account()) else: raise ValueError("Illegal auth_method") self.job_uuid = str(uuid.uuid4().hex) now = datetime.utcnow() self.uu_name = "%s-%s-%s" % (self.name, now.strftime("%Y%m%d%H%M%S"), self.job_uuid[:16]) @property def auth_method(self): """ This can be set to ``kubeconfig`` or ``service-account``. It defaults to ``kubeconfig``. For more details, please refer to: - kubeconfig: http://kubernetes.io/docs/user-guide/kubeconfig-file - service-account: http://kubernetes.io/docs/user-guide/service-accounts """ return self.kubernetes_config.auth_method @property def kubeconfig_path(self): """ Path to kubeconfig file used for cluster authentication. It defaults to "~/.kube/config", which is the default location when using minikube (http://kubernetes.io/docs/getting-started-guides/minikube). When auth_method is ``service-account`` this property is ignored. 
**WARNING**: For Python versions < 3.5 kubeconfig must point to a Kubernetes API hostname, and NOT to an IP address. For more details, please refer to: http://kubernetes.io/docs/user-guide/kubeconfig-file """ return self.kubernetes_config.kubeconfig_path @property def kubernetes_namespace(self): """ Namespace in Kubernetes where the job will run. It defaults to the default namespace in Kubernetes For more details, please refer to: https://kubernetes.io/docs/concepts/overview/working-with-objects/namespaces/ """ return self.kubernetes_config.kubernetes_namespace @property def name(self): """ A name for this job. This task will automatically append a UUID to the name before to submit to Kubernetes. """ raise NotImplementedError("subclass must define name") @property def labels(self): """ Return custom labels for kubernetes job. example:: ``{"run_dt": datetime.date.today().strftime('%F')}`` """ return {} @property def spec_schema(self): """ Kubernetes Job spec schema in JSON format, an example follows. .. code-block:: javascript { "containers": [{ "name": "pi", "image": "perl", "command": ["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"] }], "restartPolicy": "Never" } **restartPolicy** - If restartPolicy is not defined, it will be set to "Never" by default. - **Warning**: restartPolicy=OnFailure will bypass max_retrials, and restart the container until success, with the risk of blocking the Luigi task. For more informations please refer to: http://kubernetes.io/docs/user-guide/pods/multi-container/#the-spec-schema """ raise NotImplementedError("subclass must define spec_schema") @property def max_retrials(self): """ Maximum number of retrials in case of failure. """ return self.kubernetes_config.max_retrials @property def backoff_limit(self): """ Maximum number of retries before considering the job as failed. 
See: https://kubernetes.io/docs/concepts/workloads/controllers/jobs-run-to-completion/#pod-backoff-failure-policy """ return 6 @property def delete_on_success(self): """ Delete the Kubernetes workload if the job has ended successfully. """ return True @property def print_pod_logs_on_exit(self): """ Fetch and print the pod logs once the job is completed. """ return False @property def active_deadline_seconds(self): """ Time allowed to successfully schedule pods. See: https://kubernetes.io/docs/concepts/workloads/controllers/jobs-run-to-completion/#job-termination-and-cleanup """ return None @property def kubernetes_config(self): if not self._kubernetes_config: self._kubernetes_config = kubernetes() return self._kubernetes_config @property def poll_interval(self): """How often to poll Kubernetes for job status, in seconds.""" return self.__DEFAULT_POLL_INTERVAL @property def pod_creation_wait_interal(self): """Delay for initial pod creation for just submitted job in seconds""" return self.__DEFAULT_POD_CREATION_INTERVAL def __track_job(self): """Poll job status while active""" while not self.__verify_job_has_started(): time.sleep(self.poll_interval) self.__logger.debug("Waiting for Kubernetes job " + self.uu_name + " to start") self.__print_kubectl_hints() status = self.__get_job_status() while status == "RUNNING": self.__logger.debug("Kubernetes job " + self.uu_name + " is running") time.sleep(self.poll_interval) status = self.__get_job_status() assert status != "FAILED", "Kubernetes job " + self.uu_name + " failed" # status == "SUCCEEDED" self.__logger.info("Kubernetes job " + self.uu_name + " succeeded") self.signal_complete() def signal_complete(self): """Signal job completion for scheduler and dependent tasks. Touching a system file is an easy way to signal completion. example:: .. 
code-block:: python with self.output().open('w') as output_file: output_file.write('') """ pass def __get_pods(self): pod_objs = Pod.objects(self.__kube_api, namespace=self.kubernetes_namespace).filter(selector="job-name=" + self.uu_name).response["items"] return [Pod(self.__kube_api, p) for p in pod_objs] def __get_job(self): jobs = Job.objects(self.__kube_api, namespace=self.kubernetes_namespace).filter(selector="luigi_task_id=" + self.job_uuid).response["items"] assert len(jobs) == 1, "Kubernetes job " + self.uu_name + " not found" return Job(self.__kube_api, jobs[0]) def __print_pod_logs(self): for pod in self.__get_pods(): logs = pod.logs(timestamps=True).strip() self.__logger.info("Fetching logs from " + pod.name) if len(logs) > 0: for line in logs.split("\n"): self.__logger.info(line) def __print_kubectl_hints(self): self.__logger.info("To stream Pod logs, use:") for pod in self.__get_pods(): self.__logger.info("`kubectl logs -f pod/%s -n %s`" % (pod.name, pod.namespace)) def __verify_job_has_started(self): """Asserts that the job has successfully started""" # Verify that the job started self.__get_job() # Verify that the pod started pods = self.__get_pods() if not pods: self.__logger.debug("No pods found for %s, waiting for cluster state to match the job definition" % self.uu_name) time.sleep(self.pod_creation_wait_interal) pods = self.__get_pods() assert len(pods) > 0, "No pod scheduled by " + self.uu_name for pod in pods: status = pod.obj["status"] for cont_stats in status.get("containerStatuses", []): if "terminated" in cont_stats["state"]: t = cont_stats["state"]["terminated"] err_msg = "Pod %s %s (exit code %d). Logs: `kubectl logs pod/%s`" % (pod.name, t["reason"], t["exitCode"], pod.name) assert t["exitCode"] == 0, err_msg if "waiting" in cont_stats["state"]: wr = cont_stats["state"]["waiting"]["reason"] assert wr == "ContainerCreating", "Pod %s %s. 
Logs: `kubectl logs pod/%s`" % (pod.name, wr, pod.name) for cond in status.get("conditions", []): if "message" in cond: if cond["reason"] == "ContainersNotReady": return False assert cond["status"] != "False", "[ERROR] %s - %s" % (cond["reason"], cond["message"]) return True def __get_job_status(self): """Return the Kubernetes job status""" # Figure out status and return it job = self.__get_job() if "succeeded" in job.obj["status"] and job.obj["status"]["succeeded"] > 0: job.scale(replicas=0) if self.print_pod_logs_on_exit: self.__print_pod_logs() if self.delete_on_success: self.__delete_job_cascade(job) return "SUCCEEDED" if "failed" in job.obj["status"]: failed_cnt = job.obj["status"]["failed"] self.__logger.debug("Kubernetes job " + self.uu_name + " status.failed: " + str(failed_cnt)) if self.print_pod_logs_on_exit: self.__print_pod_logs() if failed_cnt > self.max_retrials: job.scale(replicas=0) # avoid more retrials return "FAILED" return "RUNNING" def __delete_job_cascade(self, job): delete_options_cascade = {"kind": "DeleteOptions", "apiVersion": "v1", "propagationPolicy": "Background"} r = self.__kube_api.delete(json=delete_options_cascade, **job.api_kwargs()) if r.status_code != 200: self.__kube_api.raise_for_status(r) def run(self): self._init_kubernetes() # Render job job_json = { "apiVersion": "batch/v1", "kind": "Job", "metadata": {"name": self.uu_name, "labels": {"spawned_by": "luigi", "luigi_task_id": self.job_uuid}}, "spec": {"backoffLimit": self.backoff_limit, "template": {"metadata": {"name": self.uu_name, "labels": {}}, "spec": self.spec_schema}}, } if self.kubernetes_namespace is not None: job_json["metadata"]["namespace"] = self.kubernetes_namespace if self.active_deadline_seconds is not None: job_json["spec"]["activeDeadlineSeconds"] = self.active_deadline_seconds # Update user labels job_json["metadata"]["labels"].update(self.labels) job_json["spec"]["template"]["metadata"]["labels"].update(self.labels) # Add default restartPolicy if not 
specified if "restartPolicy" not in self.spec_schema: job_json["spec"]["template"]["spec"]["restartPolicy"] = "Never" # Submit job self.__logger.info("Submitting Kubernetes Job: " + self.uu_name) job = Job(self.__kube_api, job_json) job.create() # Track the Job (wait while active) self.__logger.info("Start tracking Kubernetes Job: " + self.uu_name) self.__track_job() def output(self): """ An output target is necessary for checking job completion unless an alternative complete method is defined. Example:: return luigi.LocalTarget(os.path.join('/tmp', 'example')) """ pass ================================================ FILE: luigi/contrib/lsf.py ================================================ # -*- coding: utf-8 -*- """ .. Copyright 2012-2015 Spotify AB Copyright 2018 Copyright 2018 EMBL-European Bioinformatics Institute Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import logging import os import random import shutil import subprocess import sys import time try: # Dill is used for handling pickling and unpickling if there is a deference # in server setups between the LSF submission node and the nodes in the # cluster import dill as pickle except ImportError: import pickle import luigi import luigi.configuration from luigi.contrib import lsf_runner from luigi.contrib.hadoop import create_packages_archive from luigi.task_status import DONE, FAILED, PENDING, RUNNING, UNKNOWN """ LSF batch system Tasks. ======================= What's LSF? 
see http://en.wikipedia.org/wiki/Platform_LSF and https://wiki.med.harvard.edu/Orchestra/IntroductionToLSF See: https://github.com/spotify/luigi/issues/1936 This extension is modeled after the hadoop.py approach. I'll be making a few assumptions, and will try to note them. Going into it, the assumptions are: - You schedule your jobs on an LSF submission node. - The 'bjobs' command on an LSF batch submission system returns a standardized format. - All nodes have access to the code you're running. - The sysadmin won't get pissed if we run a 'bjobs' check every thirty seconds or so per job (there are ways of coalescing the bjobs calls if that's not cool). The procedure: - Pickle the class - Construct a bsub argument that runs a generic runner function with the path to the pickled class - Runner function loads the class from pickle - Runner function hits the work button on it """ LOGGER = logging.getLogger("luigi-interface") def track_job(job_id): """ Tracking is done by requesting each job and then searching for whether the job has one of the following states: - "RUN", - "PEND", - "SSUSP", - "EXIT" based on the LSF documentation """ cmd = ["bjobs", "-noheader", "-o", "stat", str(job_id)] track_job_proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, shell=False) status = track_job_proc.communicate()[0].strip("\n") return status def kill_job(job_id): """ Kill a running LSF job """ subprocess.call(["bkill", job_id]) class LSFJobTask(luigi.Task): """ Takes care of uploading and executing an LSF job """ n_cpu_flag = luigi.IntParameter(default=2, significant=False) shared_tmp_dir = luigi.Parameter(default="/tmp", significant=False) resource_flag = luigi.Parameter(default="mem=8192", significant=False) memory_flag = luigi.Parameter(default="8192", significant=False) queue_flag = luigi.Parameter(default="queue_name", significant=False) runtime_flag = luigi.IntParameter(default=60) job_name_flag = luigi.Parameter(default="") poll_time = luigi.FloatParameter(significant=False, 
default=5, description="specify the wait time to poll bjobs for the job status") save_job_info = luigi.BoolParameter(default=False) output = luigi.Parameter(default="") extra_bsub_args = luigi.Parameter(default="") job_status = None def fetch_task_failures(self): """ Read in the error file from bsub """ error_file = os.path.join(self.tmp_dir, "job.err") if os.path.isfile(error_file): with open(error_file, "r") as f_err: errors = f_err.readlines() else: errors = "" return errors def fetch_task_output(self): """ Read in the output file """ # Read in the output file if os.path.isfile(os.path.join(self.tmp_dir, "job.out")): with open(os.path.join(self.tmp_dir, "job.out"), "r") as f_out: outputs = f_out.readlines() else: outputs = "" return outputs def _init_local(self): base_tmp_dir = self.shared_tmp_dir random_id = "%016x" % random.getrandbits(64) task_name = random_id + self.task_id # If any parameters are directories, if we don't # replace the separators on *nix, it'll create a weird nested directory task_name = task_name.replace("/", "::") # Max filename length max_filename_length = os.fstatvfs(0).f_namemax self.tmp_dir = os.path.join(base_tmp_dir, task_name[:max_filename_length]) LOGGER.info("Tmp dir: %s", self.tmp_dir) os.makedirs(self.tmp_dir) # Dump the code to be run into a pickle file LOGGER.debug("Dumping pickled class") self._dump(self.tmp_dir) # Make sure that all the class's dependencies are tarred and available LOGGER.debug("Tarballing dependencies") # Grab luigi and the module containing the code to be run packages = [luigi, __import__(self.__module__, None, None, "dummy")] create_packages_archive(packages, os.path.join(self.tmp_dir, "packages.tar")) # Now, pass onto the class's specified init_local() method. self.init_local() def init_local(self): """ Implement any work to setup any internal datastructure etc here. You can add extra input using the requires_local/input_local methods. 
Anything you set on the object will be pickled and available on the compute nodes. """ pass def run(self): """ The procedure: - Pickle the class - Tarball the dependencies - Construct a bsub argument that runs a generic runner function with the path to the pickled class - Runner function loads the class from pickle - Runner class untars the dependencies - Runner function hits the button on the class's work() method """ self._init_local() self._run_job() def work(self): """ Subclass this for where you're doing your actual work. Why not run(), like other tasks? Because we need run to always be something that the Worker can call, and that's the real logical place to do LSF scheduling. So, the work will happen in work(). """ pass def _dump(self, out_dir=""): """ Dump instance to file. """ self.job_file = os.path.join(out_dir, "job-instance.pickle") if self.__module__ == "__main__": dump_inst = pickle.dumps(self) module_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0] dump_inst = dump_inst.replace("(c__main__", "(c" + module_name) open(self.job_file, "w").write(dump_inst) else: pickle.dump(self, open(self.job_file, "w")) def _run_job(self): """ Build a bsub argument that will run lsf_runner.py on the directory we've specified. """ args = [] if isinstance(self.output(), list): log_output = os.path.split(self.output()[0].path) else: log_output = os.path.split(self.output().path) args += ["bsub", "-q", self.queue_flag] args += ["-n", str(self.n_cpu_flag)] args += ["-M", str(self.memory_flag)] args += ["-R", "rusage[%s]" % self.resource_flag] args += ["-W", str(self.runtime_flag)] if self.job_name_flag: args += ["-J", str(self.job_name_flag)] args += ["-o", os.path.join(log_output[0], "job.out")] args += ["-e", os.path.join(log_output[0], "job.err")] if self.extra_bsub_args: args += self.extra_bsub_args.split() # Find where the runner file is runner_path = os.path.abspath(lsf_runner.__file__) args += [runner_path] args += [self.tmp_dir] # That should do it. 
Let the world know what we're doing. LOGGER.info("### LSF SUBMISSION ARGS: %s", " ".join([str(a) for a in args])) # Submit the job run_job_proc = subprocess.Popen([str(a) for a in args], stdin=subprocess.PIPE, stdout=subprocess.PIPE, cwd=self.tmp_dir) output = run_job_proc.communicate()[0] # ASSUMPTION # The result will be of the format # Job <123> is submitted ot queue # So get the number in those first brackets. # I cannot think of a better workaround that leaves logic on the Task side of things. LOGGER.info("### JOB SUBMISSION OUTPUT: %s", str(output)) self.job_id = int(output.split("<")[1].split(">")[0]) LOGGER.info("Job %ssubmitted as job %s", self.job_name_flag + " ", str(self.job_id)) self._track_job() # If we want to save the job temporaries, then do so # We'll move them to be next to the job output if self.save_job_info: LOGGER.info("Saving up temporary bits") # dest_dir = self.output().path shutil.move(self.tmp_dir, "/".join(log_output[0:-1])) # Now delete the temporaries, if they're there. self._finish() def _track_job(self): time0 = 0 while True: # Sleep for a little bit time.sleep(self.poll_time) # See what the job's up to # ASSUMPTION lsf_status = track_job(self.job_id) if lsf_status == "RUN": self.job_status = RUNNING LOGGER.info("Job is running...") if time0 == 0: time0 = int(round(time.time())) elif lsf_status == "PEND": self.job_status = PENDING LOGGER.info("Job is pending...") elif lsf_status == "DONE" or lsf_status == "EXIT": # Then the job could either be failed or done. 
errors = self.fetch_task_failures() if not errors: self.job_status = DONE LOGGER.info("Job is done") time1 = int(round(time.time())) # Return a near estimate of the run time to with +/- the # self.poll_time job_name = str(self.job_id) if self.job_name_flag: job_name = "%s %s" % (self.job_name_flag, job_name) LOGGER.info("### JOB COMPLETED: %s in %s seconds", job_name, str(time1 - time0)) else: self.job_status = FAILED LOGGER.error("Job has FAILED") LOGGER.error("\n\n") LOGGER.error("Traceback: ") for error in errors: LOGGER.error(error) break elif lsf_status == "SSUSP": self.job_status = PENDING LOGGER.info("Job is suspended (basically, pending)...") else: self.job_status = UNKNOWN LOGGER.info("Job status is UNKNOWN!") LOGGER.info("Status is : %s", lsf_status) break def _finish(self): LOGGER.info("Cleaning up temporary bits") if self.tmp_dir and os.path.exists(self.tmp_dir): LOGGER.info("Removing directory %s", self.tmp_dir) shutil.rmtree(self.tmp_dir) def __del__(self): pass # self._finish() class LocalLSFJobTask(LSFJobTask): """ A local version of JobTask, for easier debugging. """ def run(self): self.init_local() self.work() ================================================ FILE: luigi/contrib/lsf_runner.py ================================================ # -*- coding: utf-8 -*- """ .. Copyright 2012-2015 Spotify AB Copyright 2018 Copyright 2018 EMBL-European Bioinformatics Institute Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
""" import os import sys try: # Dill is used for handling pickling and unpickling if there is a deference # in server setups between the LSF submission node and the nodes in the # cluster import dill as pickle except ImportError: import pickle import logging from luigi.safe_extractor import SafeExtractor def do_work_on_compute_node(work_dir): # Extract the necessary dependencies extract_packages_archive(work_dir) # Open up the pickle file with the work to be done os.chdir(work_dir) with open("job-instance.pickle", "r") as pickle_file_handle: job = pickle.load(pickle_file_handle) # Do the work contained job.work() def extract_packages_archive(work_dir): package_file = os.path.join(work_dir, "packages.tar") if not os.path.exists(package_file): return curdir = os.path.abspath(os.curdir) os.chdir(work_dir) extractor = SafeExtractor(work_dir) extractor.safe_extract(package_file) if "" not in sys.path: sys.path.insert(0, "") os.chdir(curdir) def main(args=sys.argv): """Run the work() method from the class instance in the file "job-instance.pickle".""" try: # Set up logging. logging.basicConfig(level=logging.WARN) work_dir = args[1] assert os.path.exists(work_dir), "First argument to lsf_runner.py must be a directory that exists" do_work_on_compute_node(work_dir) except Exception as exc: # Dump encoded data that we will try to fetch using mechanize print(exc) raise if __name__ == "__main__": main() ================================================ FILE: luigi/contrib/mongodb.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2017 Big Datext Inc # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from luigi.target import Target class MongoTarget(Target): """Target for a resource in MongoDB""" def __init__(self, mongo_client, index, collection): """ :param mongo_client: MongoClient instance :type mongo_client: MongoClient :param index: database index :type index: str :param collection: index collection :type collection: str """ self._mongo_client = mongo_client self._index = index self._collection = collection def __str__(self): return f"{self._index}/{self._collection}" def get_collection(self): """ Return targeted mongo collection to query on """ db_mongo = self._mongo_client[self._index] return db_mongo[self._collection] def get_index(self): """ Return targeted mongo index to query on """ return self._mongo_client[self._index] class MongoCellTarget(MongoTarget): """Target for a ressource in a specific field from a MongoDB document""" def __init__(self, mongo_client, index, collection, document_id, path): """ :param document_id: targeted mongo document :type document_id: str :param path: full path to the targeted field in the mongo document :type path: str """ super(MongoCellTarget, self).__init__(mongo_client, index, collection) self._document_id = document_id self._path = path def exists(self): """ Test if target has been run Target is considered run if the targeted field exists """ return self.read() is not None def read(self): """ Read the target value Use $project aggregate operator in order to support nested objects """ result = self.get_collection().aggregate([{"$match": {"_id": self._document_id}}, {"$project": {"_value": "$" + self._path, "_id": 
False}}]) for doc in result: if "_value" not in doc: break return doc["_value"] def write(self, value): """ Write value to the target """ self.get_collection().update_one({"_id": self._document_id}, {"$set": {self._path: value}}, upsert=True) class MongoRangeTarget(MongoTarget): """Target for a level 0 field in a range of documents""" def __init__(self, mongo_client, index, collection, document_ids, field): """ :param document_ids: targeted mongo documents :type documents_ids: list of str :param field: targeted field in documents :type field: str """ super(MongoRangeTarget, self).__init__(mongo_client, index, collection) self._document_ids = document_ids self._field = field def exists(self): """ Test if target has been run Target is considered run if the targeted field exists in ALL documents """ return not self.get_empty_ids() def read(self): """ Read the targets value """ cursor = self.get_collection().find({"_id": {"$in": self._document_ids}, self._field: {"$exists": True}}, {self._field: True}) return {doc["_id"]: doc[self._field] for doc in cursor} def write(self, values): """ Write values to the targeted documents Values need to be a dict as : {document_id: value} """ # Insert only for docs targeted by the target filtered = {_id: value for _id, value in values.items() if _id in self._document_ids} if not filtered: return bulk = self.get_collection().initialize_ordered_bulk_op() for _id, value in filtered.items(): bulk.find({"_id": _id}).upsert().update_one({"$set": {self._field: value}}) bulk.execute() def get_empty_ids(self): """ Get documents id with missing targeted field """ cursor = self.get_collection().find({"_id": {"$in": self._document_ids}, self._field: {"$exists": True}}, {"_id": True}) return set(self._document_ids) - {doc["_id"] for doc in cursor} class MongoCollectionTarget(MongoTarget): """Target for existing collection""" def __init__(self, mongo_client, index, collection): super(MongoCollectionTarget, self).__init__(mongo_client, index, 
collection) def exists(self): """ Test if target has been run Target is considered run if the targeted collection exists in the database """ return self.read() def read(self): """ Return if the target collection exists in the database """ return self._collection in self.get_index().collection_names() class MongoCountTarget(MongoTarget): """Target for documents count""" def __init__(self, mongo_client, index, collection, target_count): """ :param target_count: Value of the desired item count in the target :type field: int """ super(MongoCountTarget, self).__init__(mongo_client, index, collection) self._target_count = target_count def exists(self): """ Test if the target has been run Target is considered run if the number of items in the target matches value of self._target_count """ return self.read() == self._target_count def read(self): """ Using the aggregate method to avoid inaccurate count if using a sharded cluster https://docs.mongodb.com/manual/reference/method/db.collection.count/#behavior """ for res in self.get_collection().aggregate([{"$group": {"_id": None, "count": {"$sum": 1}}}]): return res.get("count", None) return None ================================================ FILE: luigi/contrib/mssqldb.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import logging import luigi logger = logging.getLogger("luigi-interface") try: from pymssql import _mssql except ImportError: logger.warning( "Loading MSSQL module without the python package pymssql. \ This will crash at runtime if SQL Server functionality is used." ) class MSSqlTarget(luigi.Target): """ Target for a resource in Microsoft SQL Server. This module is primarily derived from mysqldb.py. Much of MSSqlTarget, MySqlTarget and PostgresTarget are similar enough to potentially add a RDBMSTarget abstract base class to rdbms.py that these classes could be derived from. """ marker_table = luigi.configuration.get_config().get("mssql", "marker-table", "table_updates") def __init__(self, host, database, user, password, table, update_id): """ Initializes a MsSqlTarget instance. :param host: MsSql server address. Possibly a host:port string. :type host: str :param database: database name. :type database: str :param user: database user :type user: str :param password: password for specified user. :type password: str :param update_id: an identifier for this data set. :type update_id: str """ if ":" in host: self.host, self.port = host.split(":") self.port = int(self.port) else: self.host = host self.port = 1433 self.database = database self.user = user self.password = password self.table = table self.update_id = update_id def __str__(self): return self.table def touch(self, connection=None): """ Mark this update as complete. IMPORTANT, If the marker table doesn't exist, the connection transaction will be aborted and the connection reset. Then the marker table will be created. 
""" self.create_marker_table() if connection is None: connection = self.connect() connection.execute_non_query( """IF NOT EXISTS(SELECT 1 FROM {marker_table} WHERE update_id = %(update_id)s) INSERT INTO {marker_table} (update_id, target_table) VALUES (%(update_id)s, %(table)s) ELSE UPDATE t SET target_table = %(table)s , inserted = GETDATE() FROM {marker_table} t WHERE update_id = %(update_id)s """.format(marker_table=self.marker_table), {"update_id": self.update_id, "table": self.table}, ) # make sure update is properly marked assert self.exists(connection) def exists(self, connection=None): if connection is None: connection = self.connect() try: row = connection.execute_row( """SELECT 1 FROM {marker_table} WHERE update_id = %s """.format(marker_table=self.marker_table), (self.update_id,), ) except _mssql.MssqlDatabaseException as e: # Error number for table doesn't exist if e.number == 208: row = None else: raise return row is not None def connect(self): """ Create a SQL Server connection and return a connection object """ connection = _mssql.connect(user=self.user, password=self.password, server=self.host, port=self.port, database=self.database) return connection def create_marker_table(self): """ Create marker table if it doesn't exist. Use a separate connection since the transaction might have to be reset. 
""" connection = self.connect() try: connection.execute_non_query( """ CREATE TABLE {marker_table} ( id BIGINT NOT NULL IDENTITY(1,1), update_id VARCHAR(128) NOT NULL, target_table VARCHAR(128), inserted DATETIME DEFAULT(GETDATE()), PRIMARY KEY (update_id) ) """.format(marker_table=self.marker_table) ) except _mssql.MssqlDatabaseException as e: # Table already exists code if e.number == 2714: pass else: raise connection.close() ================================================ FILE: luigi/contrib/mysqldb.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import logging import luigi from luigi.contrib import rdbms logger = logging.getLogger("luigi-interface") try: import mysql.connector from mysql.connector import Error, errorcode except ImportError: logger.warning( "Loading MySQL module without the python package mysql-connector-python. \ This will crash at runtime if MySQL functionality is used." ) class MySqlTarget(luigi.Target): """ Target for a resource in MySql. """ marker_table = luigi.configuration.get_config().get("mysql", "marker-table", "table_updates") def __init__(self, host, database, user, password, table, update_id, **cnx_kwargs): """ Initializes a MySqlTarget instance. :param host: MySql server address. Possibly a host:port string. :type host: str :param database: database name. 
:type database: str :param user: database user :type user: str :param password: password for specified user. :type password: str :param update_id: an identifier for this data set. :type update_id: str :param cnx_kwargs: optional params for mysql connector constructor. See https://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html. """ if ":" in host: self.host, self.port = host.split(":") self.port = int(self.port) else: self.host = host self.port = 3306 self.database = database self.user = user self.password = password self.table = table self.update_id = update_id self.cnx_kwargs = cnx_kwargs def __str__(self): return self.table def touch(self, connection=None): """ Mark this update as complete. IMPORTANT, If the marker table doesn't exist, the connection transaction will be aborted and the connection reset. Then the marker table will be created. """ self.create_marker_table() if connection is None: connection = self.connect() connection.autocommit = True # if connection created here, we commit it here connection.cursor().execute( """INSERT INTO {marker_table} (update_id, target_table) VALUES (%s, %s) ON DUPLICATE KEY UPDATE update_id = VALUES(update_id) """.format(marker_table=self.marker_table), (self.update_id, self.table), ) # make sure update is properly marked assert self.exists(connection) def exists(self, connection=None): if connection is None: connection = self.connect() connection.autocommit = True cursor = connection.cursor() try: cursor.execute( """SELECT 1 FROM {marker_table} WHERE update_id = %s LIMIT 1""".format(marker_table=self.marker_table), (self.update_id,), ) row = cursor.fetchone() except mysql.connector.Error as e: if e.errno == errorcode.ER_NO_SUCH_TABLE: row = None else: raise return row is not None def connect(self, autocommit=False): connection = mysql.connector.connect( user=self.user, password=self.password, host=self.host, port=self.port, database=self.database, autocommit=autocommit, **self.cnx_kwargs ) return 
connection def create_marker_table(self): """ Create marker table if it doesn't exist. Using a separate connection since the transaction might have to be reset. """ connection = self.connect(autocommit=True) cursor = connection.cursor() try: cursor.execute( """ CREATE TABLE {marker_table} ( id BIGINT(20) NOT NULL AUTO_INCREMENT, update_id VARCHAR(128) NOT NULL, target_table VARCHAR(128), inserted TIMESTAMP DEFAULT NOW(), PRIMARY KEY (update_id), KEY id (id) ) """.format(marker_table=self.marker_table) ) except mysql.connector.Error as e: if e.errno == errorcode.ER_TABLE_EXISTS_ERROR: pass else: raise connection.close() class CopyToTable(rdbms.CopyToTable): """ Template task for inserting a data set into MySQL Usage: Subclass and override the required `host`, `database`, `user`, `password`, `table` and `columns` attributes. To customize how to access data from an input task, override the `rows` method with a generator that yields each row as a tuple with fields ordered according to `columns`. """ def rows(self): """ Return/yield tuples or lists corresponding to each row to be inserted. """ with self.input().open("r") as fobj: for line in fobj: yield line.strip("\n").split("\t") # everything below will rarely have to be overridden def output(self): """ Returns a MySqlTarget representing the inserted dataset. Normally you don't override this. 
""" return MySqlTarget(host=self.host, database=self.database, user=self.user, password=self.password, table=self.table, update_id=self.update_id) def copy(self, cursor, file=None): values = "({})".format(",".join(["%s" for i in range(len(self.columns))])) columns = "({})".format(",".join([c[0] for c in self.columns])) query = "INSERT INTO {} {} VALUES {}".format(self.table, columns, values) rows = [] for idx, row in enumerate(self.rows()): rows.append(row) if (idx + 1) % self.bulk_size == 0: cursor.executemany(query, rows) rows = [] cursor.executemany(query, rows) def run(self): """ Inserts data generated by rows() into target table. If the target table doesn't exist, self.create_table will be called to attempt to create the table. Normally you don't want to override this. """ if not (self.table and self.columns): raise Exception("table and columns need to be specified") connection = self.output().connect() # attempt to copy the data into mysql # if it fails because the target table doesn't exist # try to create it by running self.create_table for attempt in range(2): try: cursor = connection.cursor() print("caling init copy...") self.init_copy(connection) self.copy(cursor) self.post_copy(connection) if self.enable_metadata_columns: self.post_copy_metacolumns(cursor) except Error as err: if err.errno == errorcode.ER_NO_SUCH_TABLE and attempt == 0: # if first attempt fails with "relation not found", try creating table # logger.info("Creating table %s", self.table) connection.reconnect() self.create_table(connection) else: raise else: break # mark as complete in same transaction self.output().touch(connection) connection.commit() connection.close() @property def bulk_size(self): return 10000 ================================================ FILE: luigi/contrib/opener.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this 
file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """OpenerTarget support, allows easier testing and configuration by abstracting out the LocalTarget, S3Target, and MockTarget types. Example: .. code-block:: python from luigi.contrib.opener import OpenerTarget OpenerTarget('/local/path.txt') OpenerTarget('s3://zefr/remote/path.txt') """ from __future__ import annotations import json from urllib.parse import parse_qs, urlsplit from luigi.contrib.s3 import S3Target from luigi.local_target import LocalTarget from luigi.mock import MockTarget from luigi.target import FileSystemException __all__ = ["OpenerError", "NoOpenerError", "InvalidQuery", "OpenerRegistry", "Opener", "MockOpener", "LocalOpener", "S3Opener", "opener", "OpenerTarget"] class OpenerError(FileSystemException): """The base exception thrown by openers""" pass class NoOpenerError(OpenerError): """Thrown when there is no opener for the given protocol""" pass class InvalidQuery(OpenerError): """Thrown when an opener is passed unexpected arguments""" pass class OpenerRegistry: def __init__(self, openers=None): """An opener registry that stores a number of opener objects used to parse Target URIs :param openers: A list of objects inherited from the Opener class. 
:type openers: list """ if openers is None: openers = [] self.registry = {} self.openers = {} self.default_opener = "file" for opener in openers: self.add(opener) def get_opener(self, name): """Retrieve an opener for the given protocol :param name: name of the opener to open :type name: string :raises NoOpenerError: if no opener has been registered of that name """ if name not in self.registry: raise NoOpenerError("No opener for %s" % name) index = self.registry[name] return self.openers[index] def add(self, opener): """Adds an opener to the registry :param opener: Opener object :type opener: Opener inherited object """ index = len(self.openers) self.openers[index] = opener for name in opener.names: self.registry[name] = index def open(self, target_uri, **kwargs): """Open target uri. :param target_uri: Uri to open :type target_uri: string :returns: Target object """ target = urlsplit(target_uri, scheme=self.default_opener) opener = self.get_opener(target.scheme) query = opener.conform_query(target.query) target = opener.get_target(target.scheme, target.path, target.fragment, target.username, target.password, target.hostname, target.port, query, **kwargs) target.opener_path = target_uri return target class Opener: """Base class for Opener objects.""" # Dictionary of expected kwargs and flag for json loading values (bool/int) allowed_kwargs: dict[str, bool] = {} # Flag to filter out unexpected kwargs filter_kwargs = True @classmethod def conform_query(cls, query): """Converts the query string from a target uri, uses cls.allowed_kwargs, and cls.filter_kwargs to drive logic. :param query: Unparsed query string :type query: urllib.parse.unsplit(uri).query :returns: Dictionary of parsed values, everything in cls.allowed_kwargs with values set to True will be parsed as json strings. """ query = parse_qs(query, keep_blank_values=True) # Remove any unexpected keywords from the query string. 
if cls.filter_kwargs: query = {x: y for x, y in query.items() if x in cls.allowed_kwargs} for key, vals in query.items(): # Multiple values of the same name could be passed use first # Also params without strings will be treated as true values if cls.allowed_kwargs.get(key, False): val = json.loads(vals[0] or "true") else: val = vals[0] or "true" query[key] = val return query @classmethod def get_target(cls, scheme, path, fragment, username, password, hostname, port, query, **kwargs): """Override this method to use values from the parsed uri to initialize the expected target. """ raise NotImplementedError("get_target must be overridden") class MockOpener(Opener): """Mock target opener, works like LocalTarget but files are all in memory. example: * mock://foo/bar.txt """ names = ["mock"] allowed_kwargs = { "is_tmp": True, "mirror_on_stderr": True, "format": False, } @classmethod def get_target(cls, scheme, path, fragment, username, password, hostname, port, query, **kwargs): full_path = (hostname or "") + path query.update(kwargs) return MockTarget(full_path, **query) class LocalOpener(Opener): """Local filesystem opener, works with any valid system path. This is the default opener and will be used if you don't indicate which opener. 
examples: * file://relative/foo/bar/baz.txt (opens a relative file) * file:///home/user (opens a directory from a absolute path) * foo/bar.baz (file:// is the default opener) """ names = ["file"] allowed_kwargs = { "is_tmp": True, "format": False, } @classmethod def get_target(cls, scheme, path, fragment, username, password, hostname, port, query, **kwargs): full_path = (hostname or "") + path query.update(kwargs) return LocalTarget(full_path, **query) class S3Opener(Opener): """Opens a target stored on Amazon S3 storage examples: * s3://bucket/foo/bar.txt * s3://bucket/foo/bar.txt?aws_access_key_id=xxx&aws_secret_access_key=yyy """ names = ["s3", "s3n"] allowed_kwargs = { "format": False, "client": True, } filter_kwargs = False @classmethod def get_target(cls, scheme, path, fragment, username, password, hostname, port, query, **kwargs): query.update(kwargs) return S3Target("{scheme}://{hostname}{path}".format(scheme=scheme, hostname=hostname, path=path), **query) opener = OpenerRegistry( [ MockOpener, LocalOpener, S3Opener, ] ) OpenerTarget = opener.open ================================================ FILE: luigi/contrib/pai.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2017 Open Targets # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ MicroSoft OpenPAI Job wrapper for Luigi. 
"OpenPAI is an open source platform that provides complete AI model training and resource management capabilities, it is easy to extend and supports on-premise, cloud and hybrid environments in various scale." For more information about OpenPAI : https://github.com/Microsoft/pai/, this task is tested against OpenPAI 0.7.1 Requires: - requests: ``pip install requests`` Written and maintained by Liu, Dongqing (@liudongqing). """ import abc import json import logging import time from urllib.parse import urljoin import luigi logger = logging.getLogger("luigi-interface") try: import requests as rs from requests.exceptions import HTTPError except ImportError: logger.warning("requests is not installed. PaiTask requires requests.") def slot_to_dict(o): o_dict = {} for key in o.__slots__: if not key.startswith("__"): value = getattr(o, key, None) if value is not None: o_dict[key] = value return o_dict class PaiJob: """ The Open PAI job definition. Refer to here https://github.com/Microsoft/pai/blob/master/docs/job_tutorial.md :: { "jobName": String, "image": String, "authFile": String, "dataDir": String, "outputDir": String, "codeDir": String, "virtualCluster": String, "taskRoles": [ { "name": String, "taskNumber": Integer, "cpuNumber": Integer, "memoryMB": Integer, "shmMB": Integer, "gpuNumber": Integer, "portList": [ { "label": String, "beginAt": Integer, "portNumber": Integer } ], "command": String, "minFailedTaskCount": Integer, "minSucceededTaskCount": Integer } ], "gpuType": String, "retryCount": Integer } """ __slots__ = ("jobName", "image", "authFile", "dataDir", "outputDir", "codeDir", "virtualCluster", "taskRoles", "gpuType", "retryCount") def __init__(self, jobName, image, tasks): """ Initialize a Job with required fields. 
:param jobName: Name for the job, need to be unique :param image: URL pointing to the Docker image for all tasks in the job :param tasks: List of taskRole, one task role at least """ self.jobName = jobName self.image = image if isinstance(tasks, list) and len(tasks) != 0: self.taskRoles = tasks else: raise TypeError("you must specify one task at least.") class Port: __slots__ = ("label", "beginAt", "portNumber") def __init__(self, label, begin_at=0, port_number=1): """ The Port definition for TaskRole :param label: Label name for the port type, required :param begin_at: The port to begin with in the port type, 0 for random selection, required :param port_number: Number of ports for the specific type, required """ self.label = label self.beginAt = begin_at self.portNumber = port_number class TaskRole: __slots__ = ("name", "taskNumber", "cpuNumber", "memoryMB", "shmMB", "gpuNumber", "portList", "command", "minFailedTaskCount", "minSucceededTaskCount") def __init__(self, name, command, taskNumber=1, cpuNumber=1, memoryMB=2048, shmMB=64, gpuNumber=0, portList=[]): """ The TaskRole of PAI :param name: Name for the task role, need to be unique with other roles, required :param command: Executable command for tasks in the task role, can not be empty, required :param taskNumber: Number of tasks for the task role, no less than 1, required :param cpuNumber: CPU number for one task in the task role, no less than 1, required :param shmMB: Shared memory for one task in the task role, no more than memory size, required :param memoryMB: Memory for one task in the task role, no less than 100, required :param gpuNumber: GPU number for one task in the task role, no less than 0, required :param portList: List of portType to use, optional """ self.name = name self.command = command self.taskNumber = taskNumber self.cpuNumber = cpuNumber self.memoryMB = memoryMB self.shmMB = shmMB self.gpuNumber = gpuNumber self.portList = portList class OpenPai(luigi.Config): pai_url = 
luigi.Parameter(default="http://127.0.0.1:9186", description="rest server url, default is http://127.0.0.1:9186") username = luigi.Parameter(default="admin", description="your username") password = luigi.Parameter(default=None, description="your password") expiration = luigi.IntParameter(default=3600, description="expiration time in seconds") class PaiTask(luigi.Task): __POLL_TIME = 5 @property @abc.abstractmethod def name(self): """Name for the job, need to be unique, required""" return "SklearnExample" @property @abc.abstractmethod def image(self): """URL pointing to the Docker image for all tasks in the job, required""" return "openpai/pai.example.sklearn" @property @abc.abstractmethod def tasks(self): """List of taskRole, one task role at least, required""" return [] @property def auth_file_path(self): """Docker registry authentication file existing on HDFS, optional""" return None @property def data_dir(self): """Data directory existing on HDFS, optional""" return None @property def code_dir(self): """Code directory existing on HDFS, should not contain any data and should be less than 200MB, optional""" return None @property def output_dir(self): """Output directory on HDFS, $PAI_DEFAULT_FS_URI/$jobName/output will be used if not specified, optional""" return "$PAI_DEFAULT_FS_URI/{0}/output".format(self.name) @property def virtual_cluster(self): """The virtual cluster job runs on. If omitted, the job will run on default virtual cluster, optional""" return "default" @property def gpu_type(self): """Specify the GPU type to be used in the tasks. 
If omitted, the job will run on any gpu type, optional""" return None @property def retry_count(self): """Job retry count, no less than 0, optional""" return 0 def __init_token(self): self.__openpai = OpenPai() request_json = json.dumps({"username": self.__openpai.username, "password": self.__openpai.password, "expiration": self.__openpai.expiration}) logger.debug("Requesting token from OpenPai") response = rs.post(urljoin(self.__openpai.pai_url, "/api/v1/token"), headers={"Content-Type": "application/json"}, data=request_json) logger.debug("Get token response {0}".format(response.text)) if response.status_code != 200: msg = "Get token request failed, response is {}".format(response.text) logger.error(msg) raise Exception(msg) else: self.__token = response.json()["token"] def __init__(self, *args, **kwargs): """ :param pai_url: The rest server url of PAI clusters, default is 'http://127.0.0.1:9186'. :param token: The token used to auth the rest server of PAI. """ super(PaiTask, self).__init__(*args, **kwargs) self.__init_token() def __check_job_status(self): response = rs.get(urljoin(self.__openpai.pai_url, "/api/v1/jobs/{0}".format(self.name))) logger.debug("Check job response {0}".format(response.text)) if response.status_code == 404: msg = "Job {0} is not found".format(self.name) logger.debug(msg) raise HTTPError(msg, response=response) elif response.status_code != 200: msg = "Get job request failed, response is {}".format(response.text) logger.error(msg) raise HTTPError(msg, response=response) job_state = response.json()["jobStatus"]["state"] if job_state in ["UNKNOWN", "WAITING", "RUNNING"]: logger.debug("Job {0} is running in state {1}".format(self.name, job_state)) return False else: msg = "Job {0} finished in state {1}".format(self.name, job_state) logger.info(msg) if job_state == "SUCCEED": return True else: raise RuntimeError(msg) def run(self): job = PaiJob(self.name, self.image, self.tasks) job.virtualCluster = self.virtual_cluster job.authFile = 
self.auth_file_path job.codeDir = self.code_dir job.dataDir = self.data_dir job.outputDir = self.output_dir job.retryCount = self.retry_count job.gpuType = self.gpu_type request_json = json.dumps(job, default=slot_to_dict) logger.debug("Submit job request {0}".format(request_json)) response = rs.post( urljoin(self.__openpai.pai_url, "/api/v1/jobs"), headers={"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.__token)}, data=request_json, ) logger.debug("Submit job response {0}".format(response.text)) # 202 is success for job submission, see https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md if response.status_code != 202: msg = "Submit job failed, response code is {0}, body is {1}".format(response.status_code, response.text) logger.error(msg) raise HTTPError(msg, response=response) while not self.__check_job_status(): time.sleep(self.__POLL_TIME) def output(self): return luigi.contrib.hdfs.HdfsTarget(self.output()) def complete(self): try: return self.__check_job_status() except HTTPError: return False except RuntimeError: return False ================================================ FILE: luigi/contrib/pig.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Apache Pig support. 
Example configuration section in luigi.cfg:: [pig] # pig home directory home: /usr/share/pig """ import logging import os import select import signal import subprocess import sys import tempfile from contextlib import contextmanager import luigi from luigi import configuration logger = logging.getLogger("luigi-interface") class PigJobTask(luigi.Task): def pig_home(self): return configuration.get_config().get("pig", "home", "/usr/share/pig") def pig_command_path(self): return os.path.join(self.pig_home(), "bin/pig") def pig_env_vars(self): """ Dictionary of environment variables that should be set when running Pig. Ex:: return { 'PIG_CLASSPATH': '/your/path' } """ return {} def pig_properties(self): """ Dictionary of properties that should be set when running Pig. Example:: return { 'pig.additional.jars':'/path/to/your/jar' } """ return {} def pig_parameters(self): """ Dictionary of parameters that should be set for the Pig job. Example:: return { 'YOUR_PARAM_NAME':'Your param value' } """ return {} def pig_options(self): """ List of options that will be appended to the Pig command. Example:: return ['-x', 'local'] """ return [] def output(self): raise NotImplementedError("subclass should define output path") def pig_script_path(self): """ Return the path to the Pig script to be run. 
""" raise NotImplementedError("subclass should define pig_script_path") @contextmanager def _build_pig_cmd(self): opts = self.pig_options() def line(k, v): return ("%s=%s%s" % (k, v, os.linesep)).encode("utf-8") with tempfile.NamedTemporaryFile() as param_file, tempfile.NamedTemporaryFile() as prop_file: if self.pig_parameters(): items = self.pig_parameters().items() param_file.writelines(line(k, v) for (k, v) in items) param_file.flush() opts.append("-param_file") opts.append(param_file.name) if self.pig_properties(): items = self.pig_properties().items() prop_file.writelines(line(k, v) for k, v in items) prop_file.flush() opts.append("-propertyFile") opts.append(prop_file.name) cmd = [self.pig_command_path()] + opts + ["-f", self.pig_script_path()] logger.info(subprocess.list2cmdline(cmd)) yield cmd def run(self): with self._build_pig_cmd() as cmd: self.track_and_progress(cmd) def track_and_progress(self, cmd): temp_stdout = tempfile.TemporaryFile("wb") env = os.environ.copy() env["PIG_HOME"] = self.pig_home() for k, v in self.pig_env_vars().items(): env[k] = v proc = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) reads = [proc.stderr.fileno(), proc.stdout.fileno()] # tracking the possible problems with this job err_lines = [] with PigRunContext(): while proc.poll() is None: ret = select.select(reads, [], []) for fd in ret[0]: if fd == proc.stderr.fileno(): line = proc.stderr.readline().decode("utf8") err_lines.append(line) if fd == proc.stdout.fileno(): line_bytes = proc.stdout.readline() temp_stdout.write(line_bytes) line = line_bytes.decode("utf8") err_line = line.lower() if err_line.find("More information at:") != -1: logger.info(err_line.split("more information at: ")[-1].strip()) if err_line.find(" - "): t = err_line.split(" - ")[-1].strip() if t != "": logger.info(t) # Read the rest + stdout err = "".join(err_lines + [an_err_line.decode("utf8") for an_err_line in proc.stderr]) if proc.returncode == 0: 
logger.info("Job completed successfully!") else: logger.error("Error when running script:\n%s", self.pig_script_path()) logger.error(err) raise PigJobError("Pig script failed with return value: %s" % (proc.returncode,), err=err) class PigRunContext: def __init__(self): self.job_id = None def __enter__(self): self.__old_signal = signal.getsignal(signal.SIGTERM) signal.signal(signal.SIGTERM, self.kill_job) return self def kill_job(self, captured_signal=None, stack_frame=None): if self.job_id: logger.info("Job interrupted, killing job %s", self.job_id) subprocess.call(["pig", "-e", '"kill %s"' % self.job_id]) if captured_signal is not None: # adding 128 gives the exit code corresponding to a signal sys.exit(128 + captured_signal) def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is KeyboardInterrupt: self.kill_job() signal.signal(signal.SIGTERM, self.__old_signal) class PigJobError(RuntimeError): def __init__(self, message, out=None, err=None): super(PigJobError, self).__init__(message, out, err) self.message = message self.out = out self.err = err def __str__(self): info = self.message if self.out: info += "\nSTDOUT: " + str(self.out) if self.err: info += "\nSTDERR: " + str(self.err) return info ================================================ FILE: luigi/contrib/postgres.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """ Implements a subclass of :py:class:`~luigi.target.Target` that writes data to Postgres. Also provides a helper task to copy data into a Postgres table. """ import datetime import logging import os import re import tempfile import luigi from luigi.contrib import rdbms logger = logging.getLogger("luigi-interface") DB_DRIVER = os.environ.get("LUIGI_PGSQL_DRIVER", "psycopg2") DB_ERROR_CODES = {} ERROR_DUPLICATE_TABLE = "duplicate_table" ERROR_UNDEFINED_TABLE = "undefined_table" dbapi = None if DB_DRIVER == "psycopg2": try: import psycopg2 as dbapi def update_error_codes(): import psycopg2.errorcodes DB_ERROR_CODES.update( { psycopg2.errorcodes.DUPLICATE_TABLE: ERROR_DUPLICATE_TABLE, psycopg2.errorcodes.UNDEFINED_TABLE: ERROR_UNDEFINED_TABLE, } ) update_error_codes() except ImportError: pass if dbapi is None or DB_DRIVER == "pg8000": try: import pg8000.core import pg8000.dbapi as dbapi # noqa: F811 # pg8000 doesn't have an error code catalog so we need to make our own # from https://www.postgresql.org/docs/8.2/errcodes-appendix.html DB_ERROR_CODES.update({"42P07": ERROR_DUPLICATE_TABLE, "42P01": ERROR_UNDEFINED_TABLE}) except ImportError: pass if dbapi is None: logger.warning("Loading postgres module without psycopg2 nor pg8000 installed. 
Will crash at runtime if postgres functionality is used.") def _is_pg8000_error(exception): try: return ( isinstance(exception, dbapi.DatabaseError) and isinstance(exception.args, tuple) and isinstance(exception.args[0], dict) and pg8000.core.RESPONSE_CODE in exception.args[0] ) except NameError: return False def _pg8000_connection_reset(connection): cursor = connection.cursor() if connection.autocommit: cursor.execute("DISCARD ALL") else: cursor.execute("ABORT") cursor.execute("BEGIN TRANSACTION") cursor.close() def db_error_code(exception): try: error_code = None if hasattr(exception, "pgcode"): error_code = exception.pgcode elif _is_pg8000_error(exception): error_code = exception.args[0][pg8000.core.RESPONSE_CODE] return DB_ERROR_CODES.get(error_code) except TypeError as error: error.__cause__ = exception raise error class MultiReplacer: """ Object for one-pass replace of multiple words Substituted parts will not be matched against other replace patterns, as opposed to when using multipass replace. The order of the items in the replace_pairs input will dictate replacement precedence. Constructor arguments: replace_pairs -- list of 2-tuples which hold strings to be replaced and replace string Usage: .. code-block:: python >>> replace_pairs = [("a", "b"), ("b", "c")] >>> MultiReplacer(replace_pairs)("abcd") 'bccd' >>> replace_pairs = [("ab", "x"), ("a", "x")] >>> MultiReplacer(replace_pairs)("ab") 'x' >>> replace_pairs.reverse() >>> MultiReplacer(replace_pairs)("ab") 'xb' """ # TODO: move to misc/util module def __init__(self, replace_pairs): """ Initializes a MultiReplacer instance. :param replace_pairs: list of 2-tuples which hold strings to be replaced and replace string. 
:type replace_pairs: tuple """ replace_list = list(replace_pairs) # make a copy in case input is iterable self._replace_dict = dict(replace_list) pattern = "|".join(re.escape(x) for x, y in replace_list) self._search_re = re.compile(pattern) def _replacer(self, match_object): # this method is used as the replace function in the re.sub below return self._replace_dict[match_object.group()] def __call__(self, search_string): # using function replacing for a per-result replace return self._search_re.sub(self._replacer, search_string) # these are the escape sequences recognized by postgres COPY # according to http://www.postgresql.org/docs/8.1/static/sql-copy.html default_escape = MultiReplacer([("\\", "\\\\"), ("\t", "\\t"), ("\n", "\\n"), ("\r", "\\r"), ("\v", "\\v"), ("\b", "\\b"), ("\f", "\\f")]) class PostgresTarget(luigi.Target): """ Target for a resource in Postgres. This will rarely have to be directly instantiated by the user. """ marker_table = luigi.configuration.get_config().get("postgres", "marker-table", "table_updates") # if not supplied, fall back to default Postgres port DEFAULT_DB_PORT = 5432 # Use DB side timestamps or client side timestamps in the marker_table use_db_timestamps = True def __init__(self, host, database, user, password, table, update_id, port=None): """ Args: host (str): Postgres server address. Possibly a host:port string. database (str): Database name user (str): Database user password (str): Password for specified user update_id (str): An identifier for this data set port (int): Postgres server port. """ if ":" in host: self.host, self.port = host.split(":") else: self.host = host self.port = port or self.DEFAULT_DB_PORT self.database = database self.user = user self.password = password self.table = table self.update_id = update_id def __str__(self): return self.table def touch(self, connection=None): """ Mark this update as complete. 
Important: If the marker table doesn't exist, the connection transaction will be aborted and the connection reset. Then the marker table will be created. """ self.create_marker_table() if connection is None: # TODO: test this connection = self.connect() connection.autocommit = True # if connection created here, we commit it here if self.use_db_timestamps: connection.cursor().execute( """INSERT INTO {marker_table} (update_id, target_table) VALUES (%s, %s) """.format(marker_table=self.marker_table), (self.update_id, self.table), ) else: connection.cursor().execute( """INSERT INTO {marker_table} (update_id, target_table, inserted) VALUES (%s, %s, %s); """.format(marker_table=self.marker_table), (self.update_id, self.table, datetime.datetime.now()), ) def exists(self, connection=None): if connection is None: connection = self.connect() connection.autocommit = True cursor = connection.cursor() try: cursor.execute( """SELECT 1 FROM {marker_table} WHERE update_id = %s LIMIT 1""".format(marker_table=self.marker_table), (self.update_id,), ) row = cursor.fetchone() except dbapi.DatabaseError as e: if db_error_code(e) == ERROR_UNDEFINED_TABLE: row = None else: raise return row is not None def connect(self): """ Get a DBAPI 2.0 connection object to the database where the table is. """ connection = dbapi.connect(host=self.host, port=self.port, database=self.database, user=self.user, password=self.password) connection.set_client_encoding("utf-8") return connection def create_marker_table(self): """ Create marker table if it doesn't exist. Using a separate connection since the transaction might have to be reset. 
""" connection = self.connect() connection.autocommit = True cursor = connection.cursor() if self.use_db_timestamps: sql = """ CREATE TABLE {marker_table} ( update_id TEXT PRIMARY KEY, target_table TEXT, inserted TIMESTAMP DEFAULT NOW()) """.format(marker_table=self.marker_table) else: sql = """ CREATE TABLE {marker_table} ( update_id TEXT PRIMARY KEY, target_table TEXT, inserted TIMESTAMP); """.format(marker_table=self.marker_table) try: cursor.execute(sql) except dbapi.DatabaseError as e: if db_error_code(e) == ERROR_DUPLICATE_TABLE: pass else: raise connection.close() def open(self, mode): raise NotImplementedError("Cannot open() PostgresTarget") class CopyToTable(rdbms.CopyToTable): """ Template task for inserting a data set into Postgres Usage: Subclass and override the required `host`, `database`, `user`, `password`, `table` and `columns` attributes. To customize how to access data from an input task, override the `rows` method with a generator that yields each row as a tuple with fields ordered according to `columns`. """ def rows(self): """ Return/yield tuples or lists corresponding to each row to be inserted. """ with self.input().open("r") as fobj: for line in fobj: yield line.strip("\n").split("\t") def map_column(self, value): """ Applied to each column of every row returned by `rows`. Default behaviour is to escape special characters and identify any self.null_values. """ if value in self.null_values: return r"\\N" else: return default_escape(str(value)) # everything below will rarely have to be overridden def output(self): """ Returns a PostgresTarget representing the inserted dataset. Normally you don't override this. 
""" return PostgresTarget( host=self.host, database=self.database, user=self.user, password=self.password, table=self.table, update_id=self.update_id, port=self.port ) def copy(self, cursor, file): if isinstance(self.columns[0], str): column_names = self.columns elif len(self.columns[0]) == 2: column_names = [c[0] for c in self.columns] else: raise Exception("columns must consist of column strings or (column string, type string) tuples (was %r ...)" % (self.columns[0],)) copy_sql = ("COPY {table} ({column_list}) FROM STDIN WITH (FORMAT text, NULL '{null_string}', DELIMITER '{delimiter}')").format( table=self.table, delimiter=self.column_separator, null_string=r"\\N", column_list=", ".join(column_names) ) # cursor.copy_expert is not available in pg8000 if hasattr(cursor, "copy_expert"): cursor.copy_expert(copy_sql, file) else: cursor.execute(copy_sql, stream=file) def run(self): """ Inserts data generated by rows() into target table. If the target table doesn't exist, self.create_table will be called to attempt to create the table. Normally you don't want to override this. 
""" if not (self.table and self.columns): raise Exception("table and columns need to be specified") connection = self.output().connect() # transform all data generated by rows() using map_column and write data # to a temporary file for import using postgres COPY tmp_dir = luigi.configuration.get_config().get("postgres", "local-tmp-dir", None) tmp_file = tempfile.TemporaryFile(dir=tmp_dir) n = 0 for row in self.rows(): n += 1 if n % 100000 == 0: logger.info("Wrote %d lines", n) rowstr = self.column_separator.join(self.map_column(val) for val in row) rowstr += "\n" tmp_file.write(rowstr.encode("utf-8")) logger.info("Done writing, importing at %s", datetime.datetime.now()) tmp_file.seek(0) # attempt to copy the data into postgres # if it fails because the target table doesn't exist # try to create it by running self.create_table for attempt in range(2): try: cursor = connection.cursor() self.init_copy(connection) self.copy(cursor, tmp_file) self.post_copy(connection) if self.enable_metadata_columns: self.post_copy_metacolumns(cursor) except dbapi.DatabaseError as e: if db_error_code(e) == ERROR_UNDEFINED_TABLE and attempt == 0: # if first attempt fails with "relation not found", try creating table logger.info("Creating table %s", self.table) # reset() is a psycopg2-specific method if hasattr(connection, "reset"): connection.reset() else: _pg8000_connection_reset(connection) self.create_table(connection) else: raise else: break # mark as complete in same transaction self.output().touch(connection) # commit and clean up connection.commit() connection.close() tmp_file.close() class PostgresQuery(rdbms.Query): """ Template task for querying a Postgres compatible database Usage: Subclass and override the required `host`, `database`, `user`, `password`, `table`, and `query` attributes. Optionally one can override the `autocommit` attribute to put the connection for the query in autocommit mode. 
Override the `run` method if your use case requires some action with the query result. Task instances require a dynamic `update_id`, e.g. via parameter(s), otherwise the query will only execute once To customize the query signature as recorded in the database marker table, override the `update_id` property. """ def run(self): connection = self.output().connect() connection.autocommit = self.autocommit cursor = connection.cursor() sql = self.query logger.info("Executing query from task: {name}".format(name=self.__class__)) cursor.execute(sql) # Update marker table self.output().touch(connection) # commit and close connection connection.commit() connection.close() def output(self): """ Returns a PostgresTarget representing the executed query. Normally you don't override this. """ return PostgresTarget( host=self.host, database=self.database, user=self.user, password=self.password, table=self.table, update_id=self.update_id, port=self.port ) ================================================ FILE: luigi/contrib/presto.py ================================================ import inspect import logging import re from collections import OrderedDict from contextlib import closing from enum import Enum from time import sleep import luigi from luigi.contrib import rdbms from luigi.task_register import Register logger = logging.getLogger("luigi-interface") try: from pyhive.exc import DatabaseError from pyhive.presto import Connection, Cursor except ImportError: logger.warning("pyhive[presto] is not installed.") class presto(luigi.Config): # NOQA host = luigi.Parameter(default="localhost", description="Presto host") port = luigi.IntParameter(default=8090, description="Presto port") user = luigi.Parameter(default="anonymous", description="Presto user") catalog = luigi.Parameter(default="hive", description="Default catalog") password = luigi.Parameter(default=None, description="User password") protocol = luigi.Parameter(default="https", description="Presto connection protocol") 
poll_interval = luigi.FloatParameter(default=1.0, description=" how often to ask the Presto REST interface for a progress update, defaults to a second") class PrestoClient: """ Helper class wrapping `pyhive.presto.Connection` for executing presto queries and tracking progress """ def __init__(self, connection, sleep_time=1): self.sleep_time = sleep_time self._connection = connection self._status = {"state": "initial"} @property def percentage_progress(self): """ :return: percentage of query overall progress """ return self._status.get("stats", {}).get("progressPercentage", 0.1) @property def info_uri(self): """ :return: query UI link """ return self._status.get("infoUri") def execute(self, query, parameters=None, mode=None): """ :param query: query to run :param parameters: parameters should be injected in the query :param mode: "fetch" - yields rows, "watch" - yields log entries :return: """ class Mode(Enum): watch = "watch" fetch = "fetch" _mode = Mode(mode) if mode else Mode.watch with closing(self._connection.cursor()) as cursor: cursor.execute(query, parameters) status = self._status while status: sleep(self.sleep_time) status = cursor.poll() if status: if _mode == Mode.watch: yield status self._status = status if _mode == Mode.fetch: for row in cursor.fetchall(): yield row class WithPrestoClient(Register): """ A metaclass for injecting `PrestoClient` as a `_client` field into a new instance of class `T` Presto connection options are taken from `T`-instance fields Fields should have the same names as in `pyhive.presto.Cursor` """ def __new__(cls, name, bases, attrs): def _client(self): def _kwargs(): """ replace to ``` (_self, *args), *_ = inspect.getfullargspec(Cursor.__init__) ``` after py2-deprecation """ args = inspect.getfullargspec(Cursor.__init__)[0][1:] for parameter in args: val = getattr(self, parameter) if val: yield parameter, val connection = Connection(**dict(_kwargs())) return PrestoClient(connection=connection) attrs.update({"_client": 
property(_client)}) return super(cls, WithPrestoClient).__new__(cls, name, bases, attrs) class PrestoTarget(luigi.Target): """ Target for presto-accessible tables """ def __init__(self, client, catalog, database, table, partition=None): self.catalog = catalog self.database = database self.table = table self.partition = partition self._client = client self._count = None def __str__(self): return self.table @property def _count_query(self): partition = OrderedDict(self.partition or {1: 1}) def _clauses(): for k in partition.keys(): yield "{} = %s".format(k) clauses = " AND ".join(_clauses()) query = "SELECT COUNT(*) AS cnt FROM {}.{}.{} WHERE {} LIMIT 1".format(self.catalog, self.database, self.table, clauses) params = list(partition.values()) return query, params def _table_doesnot_exist(self, exception): pattern = re.compile(r"line (\d+):(\d+): Table {}.{}.{} does not exist".format(self.catalog, self.database, self.table)) try: message = exception.message["message"] if pattern.match(message): return True finally: return False def count(self): if not self._count: """ replace to self._count, *_ = next(self._client.execute(*self.count_query, 'fetch')) after py2 deprecation """ self._count = next(self._client.execute(*self._count_query, mode="fetch"))[0] return self._count def exists(self): """ :return: `True` if given table exists and there are any rows in a given partition `False` if no rows in the partition exists or table is absent """ try: return self.count() > 0 except DatabaseError as exception: if self._table_doesnot_exist(exception): return False except Exception: raise class PrestoTask(rdbms.Query, metaclass=WithPrestoClient): """ Task for executing presto queries During its executions tracking url and percentage progress are set """ _tracking_url_set = False @property def host(self): return presto().host @property def port(self): return presto().port @property def user(self): return presto().user @property def username(self): return self.user @property def 
schema(self): return self.database @property def password(self): return presto().password @property def catalog(self): return presto().catalog @property def poll_interval(self): return presto().poll_interval @property def source(self): return "pyhive" @property def partition(self): return None @property def protocol(self): return "https" if self.password else presto().protocol @property def session_props(self): return None @property def requests_session(self): return None @property def requests_kwargs(self): return {"verify": False} query = None def _maybe_set_tracking_url(self): if not self._tracking_url_set: self.set_tracking_url(self._client.info_uri) self._tracking_url_set = True def _set_progress(self): self.set_progress_percentage(self._client.percentage_progress) def run(self): for _ in self._client.execute(self.query): self._maybe_set_tracking_url() self._set_progress() def output(self): return PrestoTarget( client=self._client, catalog=self.catalog, database=self.database, table=self.table, partition=self.partition, ) ================================================ FILE: luigi/contrib/prometheus_metric.py ================================================ from prometheus_client import CONTENT_TYPE_LATEST, CollectorRegistry, Counter, Gauge, generate_latest from luigi import parameter from luigi.metrics import MetricsCollector from luigi.task import Config class prometheus(Config): use_task_family_in_labels = parameter.BoolParameter(default=True, parsing=parameter.BoolParameter.EXPLICIT_PARSING) task_parameters_to_use_in_labels = parameter.ListParameter(default=()) class PrometheusMetricsCollector(MetricsCollector): def _generate_task_labels(self, task): return {label: task.family if label == "family" else task.params.get(label) for label in self.labels} def __init__(self, *args, **kwargs): super(PrometheusMetricsCollector, self).__init__() self.registry = CollectorRegistry() config = prometheus(**kwargs) self.labels = 
list(config.task_parameters_to_use_in_labels) if config.use_task_family_in_labels: self.labels += ["family"] if not self.labels: raise ValueError("Prometheus labels cannot be empty (see prometheus configuration)") self.task_started_counter = Counter("luigi_task_started_total", "number of started luigi tasks", self.labels, registry=self.registry) self.task_failed_counter = Counter("luigi_task_failed_total", "number of failed luigi tasks", self.labels, registry=self.registry) self.task_disabled_counter = Counter("luigi_task_disabled_total", "number of disabled luigi tasks", self.labels, registry=self.registry) self.task_done_counter = Counter("luigi_task_done_total", "number of done luigi tasks", self.labels, registry=self.registry) self.task_execution_time = Gauge("luigi_task_execution_time_seconds", "luigi task execution time in seconds", self.labels, registry=self.registry) def generate_latest(self): return generate_latest(self.registry) def handle_task_started(self, task): self.task_started_counter.labels(**self._generate_task_labels(task)).inc() self.task_execution_time.labels(**self._generate_task_labels(task)) def handle_task_failed(self, task): self.task_failed_counter.labels(**self._generate_task_labels(task)).inc() self.task_execution_time.labels(**self._generate_task_labels(task)).set(task.updated - task.time_running) def handle_task_disabled(self, task, config): self.task_disabled_counter.labels(**self._generate_task_labels(task)).inc() self.task_execution_time.labels(**self._generate_task_labels(task)).set(task.updated - task.time_running) def handle_task_done(self, task): self.task_done_counter.labels(**self._generate_task_labels(task)).inc() # time_running can be `None` if task was already complete if task.time_running is not None: self.task_execution_time.labels(**self._generate_task_labels(task)).set(task.updated - task.time_running) def configure_http_handler(self, http_handler): http_handler.set_header("Content-Type", CONTENT_TYPE_LATEST) 
================================================ FILE: luigi/contrib/pyspark_runner.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # # Copyright 2012-2020 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ The pyspark program. This module will be run by spark-submit for PySparkTask jobs. The first argument is a path to the pickled instance of the PySparkTask, other arguments are the ones returned by PySparkTask.app_options() """ import abc import logging import os import pickle import sys from luigi import configuration # this prevents the modules in the directory of this script from shadowing global packages sys.path.append(sys.path.pop(0)) class _SparkEntryPoint(metaclass=abc.ABCMeta): def __init__(self, conf): self.conf = conf @abc.abstractmethod def __enter__(self): pass @abc.abstractmethod def __exit__(self, exc_type, exc_val, exc_tb): pass class SparkContextEntryPoint(_SparkEntryPoint): sc = None def __enter__(self): from pyspark import SparkContext self.sc = SparkContext(conf=self.conf) return self.sc, self.sc def __exit__(self, exc_type, exc_val, exc_tb): self.sc.stop() class SparkSessionEntryPoint(_SparkEntryPoint): spark = None def _check_major_spark_version(self): from pyspark import __version__ as spark_version major_version = int(spark_version.split(".")[0]) if major_version < 2: raise RuntimeError( """ Apache Spark {} does not support SparkSession entrypoint. 
Try to set 'pyspark_runner.use_spark_session' to 'False' and switch to old-style syntax """.format(spark_version) ) def __enter__(self): self._check_major_spark_version() from pyspark.sql import SparkSession self.spark = SparkSession.builder.config(conf=self.conf).enableHiveSupport().getOrCreate() return self.spark, self.spark.sparkContext def __exit__(self, exc_type, exc_val, exc_tb): self.spark.stop() class AbstractPySparkRunner(object): _entry_point_class = None def __init__(self, job, *args): # Append job directory to PYTHON_PATH to enable dynamic import # of the module in which the class resides on unpickling sys.path.append(os.path.dirname(job)) with open(job, "rb") as fd: self.job = pickle.load(fd) self.args = args def run(self): from pyspark import SparkConf conf = SparkConf() self.job.setup(conf) with self._entry_point_class(conf=conf) as (entry_point, sc): self.job.setup_remote(sc) self.job.main(entry_point, *self.args) def _pyspark_runner_with(name, entry_point_class): return type(name, (AbstractPySparkRunner,), {"_entry_point_class": entry_point_class}) PySparkRunner = _pyspark_runner_with("PySparkRunner", SparkContextEntryPoint) PySparkSessionRunner = _pyspark_runner_with("PySparkSessionRunner", SparkSessionEntryPoint) def _use_spark_session(): return bool(configuration.get_config().get("pyspark_runner", "use_spark_session", False)) def _get_runner_class(): if _use_spark_session(): return PySparkSessionRunner return PySparkRunner if __name__ == "__main__": logging.basicConfig(level=logging.WARN) _get_runner_class()(*sys.argv[1:]).run() ================================================ FILE: luigi/contrib/rdbms.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ A common module for postgres like databases, such as postgres or redshift """ from __future__ import annotations import abc import logging from typing import Any import luigi import luigi.task logger = logging.getLogger("luigi-interface") class _MetadataColumnsMixin: """Provide an additional behavior that adds columns and values to tables This mixin is used to provide an additional behavior that allow a task to add generic metadata columns to every table created for both PSQL and Redshift. Example: This is a use-case example of how this mixin could come handy and how to use it. .. code:: python class CommonMetaColumnsBehavior: def update_report_execution_date_query(self): query = "UPDATE {0} " \ "SET date_param = DATE '{1}' " \ "WHERE date_param IS NULL".format(self.table, self.date) return query @property def metadata_columns(self): if self.date: cols.append(('date_param', 'VARCHAR')) return cols @property def metadata_queries(self): queries = [self.update_created_tz_query()] if self.date: queries.append(self.update_report_execution_date_query()) return queries class RedshiftCopyCSVToTableFromS3(CommonMetaColumnsBehavior, redshift.S3CopyToTable): "We have some business override here that would only add noise to the example, so let's assume that this is only a shell." 
pass class UpdateTableA(RedshiftCopyCSVToTableFromS3): date = luigi.Parameter() table = 'tableA' def queries(): return [query_content_for('/queries/deduplicate_dupes.sql')] class UpdateTableB(RedshiftCopyCSVToTableFromS3): date = luigi.Parameter() table = 'tableB' """ @property def metadata_columns(self): """Returns the default metadata columns. Those columns are columns that we want each tables to have by default. """ return [] @property def metadata_queries(self): return [] @property def enable_metadata_columns(self): return False def _add_metadata_columns(self, connection): cursor = connection.cursor() for column in self.metadata_columns: if len(column) == 0: raise ValueError( "_add_metadata_columns is unable to infer column information from column {column} for {table}".format(column=column, table=self.table) ) column_name = column[0] if not self._column_exists(cursor, column_name): logger.info("Adding missing metadata column {column} to {table}".format(column=column, table=self.table)) self._add_column_to_table(cursor, column) def _column_exists(self, cursor, column_name): if "." 
in self.table: schema, table = self.table.split(".") query = ( "SELECT 1 AS column_exists " "FROM information_schema.columns " "WHERE table_schema = LOWER('{0}') AND table_name = LOWER('{1}') AND column_name = LOWER('{2}') LIMIT 1;".format(schema, table, column_name) ) else: query = "SELECT 1 AS column_exists FROM information_schema.columns WHERE table_name = LOWER('{0}') AND column_name = LOWER('{1}') LIMIT 1;".format( self.table, column_name ) cursor.execute(query) result = cursor.fetchone() return bool(result) def _add_column_to_table(self, cursor, column): if len(column) == 1: raise ValueError("_add_column_to_table() column type not specified for {column}".format(column=column[0])) elif len(column) == 2: query = "ALTER TABLE {table} ADD COLUMN {column};".format(table=self.table, column=" ".join(column)) elif len(column) == 3: query = "ALTER TABLE {table} ADD COLUMN {column} ENCODE {encoding};".format(table=self.table, column=" ".join(column[0:2]), encoding=column[2]) else: raise ValueError("_add_column_to_table() found no matching behavior for {column}".format(column=column)) cursor.execute(query) def post_copy_metacolumns(self, cursor): logger.info("Executing post copy metadata queries") for query in self.metadata_queries: cursor.execute(query) class CopyToTable(luigi.task.MixinNaiveBulkComplete, _MetadataColumnsMixin, luigi.Task): """ An abstract task for inserting a data set into RDBMS. 
Usage: Subclass and override the following attributes: * `host`, * `database`, * `user`, * `password`, * `table` * `columns` * `port` """ @property @abc.abstractmethod def host(self): return None @property @abc.abstractmethod def database(self): return None @property @abc.abstractmethod def user(self): return None @property @abc.abstractmethod def password(self): return None @property @abc.abstractmethod def table(self): return None @property def port(self): return None # specify the columns that are to be inserted (same as are returned by columns) # overload this in subclasses with the either column names of columns to import: # e.g. ['id', 'username', 'inserted'] # or tuples with column name, postgres column type strings: # e.g. [('id', 'SERIAL PRIMARY KEY'), ('username', 'VARCHAR(255)'), ('inserted', 'DATETIME')] columns: list[Any] = [] # options null_values = (None,) # container of values that should be inserted as NULL values column_separator = "\t" # how columns are separated in the file copied into postgres def create_table(self, connection): """ Override to provide code for creating the target table. By default it will be created using types (optionally) specified in columns. If overridden, use the provided connection object for setting up the table in order to create the table and insert data using the same transaction. """ if len(self.columns[0]) == 1: # only names of columns specified, no types raise NotImplementedError("create_table() not implemented for %r and columns types not specified" % self.table) elif len(self.columns[0]) == 2: # if columns is specified as (name, type) tuples coldefs = ",".join("{name} {type}".format(name=name, type=type) for name, type in self.columns) query = "CREATE TABLE {table} ({coldefs})".format(table=self.table, coldefs=coldefs) connection.cursor().execute(query) @property def update_id(self): """ This update id will be a unique identifier for this insert on this table. 
""" return self.task_id @abc.abstractmethod def output(self): raise NotImplementedError("This method must be overridden") def init_copy(self, connection): """ Override to perform custom queries. Any code here will be formed in the same transaction as the main copy, just prior to copying data. Example use cases include truncating the table or removing all data older than X in the database to keep a rolling window of data available in the table. """ # TODO: remove this after sufficient time so most people using the # clear_table attribtue will have noticed it doesn't work anymore if hasattr(self, "clear_table"): raise Exception("The clear_table attribute has been removed. Override init_copy instead!") if self.enable_metadata_columns: self._add_metadata_columns(connection) def post_copy(self, connection): """ Override to perform custom queries. Any code here will be formed in the same transaction as the main copy, just after copying data. Example use cases include cleansing data in temp table prior to insertion into real table. """ pass @abc.abstractmethod def copy(self, cursor, file): raise NotImplementedError("This method must be overridden") class Query(luigi.task.MixinNaiveBulkComplete, luigi.Task): """ An abstract task for executing an RDBMS query. Usage: Subclass and override the following attributes: * `host`, * `database`, * `user`, * `password`, * `table`, * `query` Optionally override: * `port`, * `autocommit` * `update_id` Subclass and override the following methods: * `run` * `output` """ @property @abc.abstractmethod def host(self): """ Host of the RDBMS. Implementation should support `hostname:port` to encode port. """ return None @property def port(self): """ Override to specify port separately from host. 
""" return None @property @abc.abstractmethod def database(self): return None @property @abc.abstractmethod def user(self): return None @property @abc.abstractmethod def password(self): return None @property @abc.abstractmethod def table(self): return None @property @abc.abstractmethod def query(self): return None @property def autocommit(self): return False @property def update_id(self): """ Override to create a custom marker table 'update_id' signature for Query subclass task instances """ return self.task_id @abc.abstractmethod def run(self): raise NotImplementedError("This method must be overridden") @abc.abstractmethod def output(self): """ Override with an RDBMS Target (e.g. PostgresTarget or RedshiftTarget) to record execution in a marker table """ raise NotImplementedError("This method must be overridden") ================================================ FILE: luigi/contrib/redis_store.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import datetime import logging from luigi.parameter import Parameter from luigi.target import Target logger = logging.getLogger("luigi-interface") try: import redis except ImportError: logger.warning("Loading redis_store module without redis installed. 
Will crash at runtime if redis_store functionality is used.") class RedisTarget(Target): """Target for a resource in Redis.""" marker_prefix = Parameter(default="luigi", config_path=dict(section="redis", name="marker-prefix")) def __init__(self, host, port, db, update_id, password=None, socket_timeout=None, expire=None): """ :param host: Redis server host :type host: str :param port: Redis server port :type port: int :param db: database index :type db: int :param update_id: an identifier for this data hash :type update_id: str :param password: a password to connect to the redis server :type password: str :param socket_timeout: client socket timeout :type socket_timeout: int :param expire: timeout before the target is deleted :type expire: int """ self.host = host self.port = port self.db = db self.password = password self.socket_timeout = socket_timeout self.update_id = update_id self.expire = expire self.redis_client = redis.StrictRedis( host=self.host, port=self.port, password=self.password, db=self.db, socket_timeout=self.socket_timeout, ) def __str__(self): return self.marker_key() def marker_key(self): """ Generate a key for the indicator hash. """ return "%s:%s" % (self.marker_prefix, self.update_id) def touch(self): """ Mark this update as complete. We index the parameters `update_id` and `date`. """ marker_key = self.marker_key() self.redis_client.hset(marker_key, "update_id", self.update_id) self.redis_client.hset(marker_key, "date", datetime.datetime.now().isoformat()) if self.expire is not None: self.redis_client.expire(marker_key, self.expire) def exists(self): """ Test, if this task has been run. 
""" return self.redis_client.exists(self.marker_key()) == 1 ================================================ FILE: luigi/contrib/redshift.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import abc import json import logging import os import time import luigi from luigi.contrib import postgres, rdbms from luigi.contrib.s3 import S3PathTask, S3Target logger = logging.getLogger("luigi-interface") try: import psycopg2 import psycopg2.errorcodes except ImportError: logger.warning("Loading postgres module without psycopg2 installed. Will crash at runtime if postgres functionality is used.") class _CredentialsMixin: """ This mixin is used to provide the same credential properties for AWS to all Redshift tasks. It also provides a helper method to generate the credentials string for the task. """ @property def configuration_section(self): """ Override to change the configuration section used to obtain default credentials. """ return "redshift" @property def aws_access_key_id(self): """ Override to return the key id. """ return self._get_configuration_attribute("aws_access_key_id") @property def aws_secret_access_key(self): """ Override to return the secret access key. """ return self._get_configuration_attribute("aws_secret_access_key") @property def aws_account_id(self): """ Override to return the account id. 
""" return self._get_configuration_attribute("aws_account_id") @property def aws_arn_role_name(self): """ Override to return the arn role name. """ return self._get_configuration_attribute("aws_arn_role_name") @property def aws_session_token(self): """ Override to return the session token. """ return self._get_configuration_attribute("aws_session_token") def _get_configuration_attribute(self, attribute): config = luigi.configuration.get_config() value = config.get(self.configuration_section, attribute, default=None) if not value: value = os.environ.get(attribute.upper(), None) return value def _credentials(self): """ Return a credential string for the provided task. If no valid credentials are set, raise a NotImplementedError. """ if self.aws_account_id and self.aws_arn_role_name: return "aws_iam_role=arn:aws:iam::{id}:role/{role}".format(id=self.aws_account_id, role=self.aws_arn_role_name) elif self.aws_access_key_id and self.aws_secret_access_key: return "aws_access_key_id={key};aws_secret_access_key={secret}{opt}".format( key=self.aws_access_key_id, secret=self.aws_secret_access_key, opt=";token={}".format(self.aws_session_token) if self.aws_session_token else "" ) else: raise NotImplementedError( "Missing Credentials. " "Ensure one of the pairs of auth args below are set " "in a configuration file, environment variables or by " "being overridden in the task: " "'aws_access_key_id' AND 'aws_secret_access_key' OR " "'aws_account_id' AND 'aws_arn_role_name'" ) class RedshiftTarget(postgres.PostgresTarget): """ Target for a resource in Redshift. Redshift is similar to postgres with a few adjustments required by redshift. """ marker_table = luigi.configuration.get_config().get("redshift", "marker-table", "table_updates") # if not supplied, fall back to default Redshift port DEFAULT_DB_PORT = 5439 use_db_timestamps = False class S3CopyToTable(rdbms.CopyToTable, _CredentialsMixin): """ Template task for inserting a data set into Redshift from s3. 
Usage: * Subclass and override the required attributes: * `host`, * `database`, * `user`, * `password`, * `table`, * `columns`, * `s3_load_path`. * You can also override the attributes provided by the CredentialsMixin if they are not supplied by your configuration or environment variables. """ @abc.abstractmethod def s3_load_path(self): """ Override to return the load path. """ return None @property @abc.abstractmethod def copy_options(self): """ Add extra copy options, for example: * TIMEFORMAT 'auto' * IGNOREHEADER 1 * TRUNCATECOLUMNS * IGNOREBLANKLINES * DELIMITER '\t' """ return "" @property def prune_table(self): """ Override to set equal to the name of the table which is to be pruned. Intended to be used in conjunction with prune_column and prune_date i.e. copy to temp table, prune production table to prune_column with a date greater than prune_date, then insert into production table from temp table """ return None @property def prune_column(self): """ Override to set equal to the column of the prune_table which is to be compared Intended to be used in conjunction with prune_table and prune_date i.e. copy to temp table, prune production table to prune_column with a date greater than prune_date, then insert into production table from temp table """ return None @property def prune_date(self): """ Override to set equal to the date by which prune_column is to be compared Intended to be used in conjunction with prune_table and prune_column i.e. 
copy to temp table, prune production table to prune_column with a date greater than prune_date, then insert into production table from temp table """ return None @property def table_attributes(self): """ Add extra table attributes, for example: DISTSTYLE KEY DISTKEY (MY_FIELD) SORTKEY (MY_FIELD_2, MY_FIELD_3) """ return "" @property def table_constraints(self): """ Add extra table constraints, for example: PRIMARY KEY (MY_FIELD, MY_FIELD_2) UNIQUE KEY (MY_FIELD_3) """ return "" @property def do_truncate_table(self): """ Return True if table should be truncated before copying new data in. """ return False def do_prune(self): """ Return True if prune_table, prune_column, and prune_date are implemented. If only a subset of prune variables are override, an exception is raised to remind the user to implement all or none. Prune (data newer than prune_date deleted) before copying new data in. """ if self.prune_table and self.prune_column and self.prune_date: return True elif self.prune_table or self.prune_column or self.prune_date: raise Exception("override zero or all prune variables") else: return False @property def table_type(self): """ Return table type (i.e. 'temp'). """ return "" @property def queries(self): """ Override to return a list of queries to be executed in order. """ return [] def truncate_table(self, connection): query = "truncate %s" % self.table cursor = connection.cursor() try: cursor.execute(query) finally: cursor.close() def prune(self, connection): query = "delete from %s where %s >= %s" % (self.prune_table, self.prune_column, self.prune_date) cursor = connection.cursor() try: cursor.execute(query) finally: cursor.close() def create_schema(self, connection): """ Will create the schema in the database """ if "." 
not in self.table: return query = "CREATE SCHEMA IF NOT EXISTS {schema_name};".format(schema_name=self.table.split(".")[0]) connection.cursor().execute(query) def create_table(self, connection): """ Override to provide code for creating the target table. By default it will be created using types (optionally) specified in columns. If overridden, use the provided connection object for setting up the table in order to create the table and insert data using the same transaction. """ if len(self.columns[0]) == 1: # only names of columns specified, no types raise NotImplementedError("create_table() not implemented for %r and columns types not specified" % self.table) elif len(self.columns[0]) == 2: # if columns is specified as (name, type) tuples coldefs = ",".join("{name} {type}".format(name=name, type=type) for name, type in self.columns) table_constraints = "" if self.table_constraints != "": table_constraints = ", " + self.table_constraints query = ("CREATE {type} TABLE {table} ({coldefs} {table_constraints}) {table_attributes}").format( type=self.table_type, table=self.table, coldefs=coldefs, table_constraints=table_constraints, table_attributes=self.table_attributes ) connection.cursor().execute(query) elif len(self.columns[0]) == 3: # if columns is specified as (name, type, encoding) tuples # possible column encodings: https://docs.aws.amazon.com/redshift/latest/dg/c_Compression_encodings.html coldefs = ",".join("{name} {type} ENCODE {encoding}".format(name=name, type=type, encoding=encoding) for name, type, encoding in self.columns) table_constraints = "" if self.table_constraints != "": table_constraints = "," + self.table_constraints query = ("CREATE {type} TABLE {table} ({coldefs} {table_constraints}) {table_attributes}").format( type=self.table_type, table=self.table, coldefs=coldefs, table_constraints=table_constraints, table_attributes=self.table_attributes ) connection.cursor().execute(query) else: raise ValueError("create_table() found no columns for %r" 
% self.table) def run(self): """ If the target table doesn't exist, self.create_table will be called to attempt to create the table. """ if not (self.table): raise Exception("table need to be specified") path = self.s3_load_path() output = self.output() connection = output.connect() cursor = connection.cursor() self.init_copy(connection) self.copy(cursor, path) self.post_copy(cursor) if self.enable_metadata_columns: self.post_copy_metacolumns(cursor) # update marker table output.touch(connection) connection.commit() # commit and clean up connection.close() def copy(self, cursor, f): """ Defines copying from s3 into redshift. If both key-based and role-based credentials are provided, role-based will be used. """ logger.info("Inserting file: %s", f) colnames = "" if self.columns and len(self.columns) > 0: colnames = ",".join([x[0] for x in self.columns]) colnames = "({})".format(colnames) cursor.execute( """ COPY {table} {colnames} from '{source}' CREDENTIALS '{creds}' {options} ;""".format(table=self.table, colnames=colnames, source=f, creds=self._credentials(), options=self.copy_options) ) def output(self): """ Returns a RedshiftTarget representing the inserted dataset. Normally you don't override this. """ return RedshiftTarget(host=self.host, database=self.database, user=self.user, password=self.password, table=self.table, update_id=self.update_id) def does_schema_exist(self, connection): """ Determine whether the schema already exists. """ if "." in self.table: query = "select 1 as schema_exists from pg_namespace where nspname = lower(%s) limit 1" else: return True cursor = connection.cursor() try: schema = self.table.split(".")[0] cursor.execute(query, [schema]) result = cursor.fetchone() return bool(result) finally: cursor.close() def does_table_exist(self, connection): """ Determine whether the table already exists. """ if "." 
in self.table: query = "select 1 as table_exists from information_schema.tables where table_schema = lower(%s) and table_name = lower(%s) limit 1" else: query = "select 1 as table_exists from pg_table_def where tablename = lower(%s) limit 1" cursor = connection.cursor() try: cursor.execute(query, tuple(self.table.split("."))) result = cursor.fetchone() return bool(result) finally: cursor.close() def init_copy(self, connection): """ Perform pre-copy sql - such as creating table, truncating, or removing data older than x. """ if not self.does_schema_exist(connection): logger.info("Creating schema for %s", self.table) self.create_schema(connection) if not self.does_table_exist(connection): logger.info("Creating table %s", self.table) self.create_table(connection) if self.enable_metadata_columns: self._add_metadata_columns(connection) if self.do_truncate_table: logger.info("Truncating table %s", self.table) self.truncate_table(connection) if self.do_prune(): logger.info("Removing %s older than %s from %s", self.prune_column, self.prune_date, self.prune_table) self.prune(connection) def post_copy(self, cursor): """ Performs post-copy sql - such as cleansing data, inserting into production table (if copied to temp table), etc. """ logger.info("Executing post copy queries") for query in self.queries: cursor.execute(query) def post_copy_metacolums(self, cursor): """ Performs post-copy to fill metadata columns. """ logger.info("Executing post copy metadata queries") for query in self.metadata_queries: cursor.execute(query) class S3CopyJSONToTable(S3CopyToTable, _CredentialsMixin): """ Template task for inserting a JSON data set into Redshift from s3. Usage: * Subclass and override the required attributes: * `host`, * `database`, * `user`, * `password`, * `table`, * `columns`, * `s3_load_path`, * `jsonpath`, * `copy_json_options`. * You can also override the attributes provided by the CredentialsMixin if they are not supplied by your configuration or environment variables. 
""" @property @abc.abstractmethod def jsonpath(self): """ Override the jsonpath schema location for the table. """ return "" @property @abc.abstractmethod def copy_json_options(self): """ Add extra copy options, for example: * GZIP * LZOP """ return "" def copy(self, cursor, f): """ Defines copying JSON from s3 into redshift. """ logger.info("Inserting file: %s", f) cursor.execute( """ COPY %s from '%s' CREDENTIALS '%s' JSON AS '%s' %s %s ;""" % (self.table, f, self._credentials(), self.jsonpath, self.copy_json_options, self.copy_options) ) class RedshiftManifestTask(S3PathTask): """ Generic task to generate a manifest file that can be used in S3CopyToTable in order to copy multiple files from your s3 folder into a redshift table at once. For full description on how to use the manifest file see http://docs.aws.amazon.com/redshift/latest/dg/loading-data-files-using-manifest.html Usage: * requires parameters * path - s3 path to the generated manifest file, including the name of the generated file to be copied into a redshift table * folder_paths - s3 paths to the folders containing files you wish to be copied Output: * generated manifest file """ # should be over ridden to point to a variety # of folders you wish to copy from folder_paths = luigi.Parameter() text_target = True def run(self): entries = [] for folder_path in self.folder_paths: s3 = S3Target(folder_path) client = s3.fs for file_name in client.list(s3.path): entries.append({"url": "%s/%s" % (folder_path, file_name), "mandatory": True}) manifest = {"entries": entries} target = self.output().open("w") dump = json.dumps(manifest) if not self.text_target: dump = dump.encode("utf8") target.write(dump) target.close() class KillOpenRedshiftSessions(luigi.Task): """ An task for killing any open Redshift sessions in a given database. This is necessary to prevent open user sessions with transactions against the table from blocking drop or truncate table commands. 
Usage: Subclass and override the required `host`, `database`, `user`, and `password` attributes. """ # time in seconds to wait before # reconnecting to Redshift if our session is killed too. # 30 seconds is usually fine; 60 is conservative connection_reset_wait_seconds = luigi.IntParameter(default=60) @property @abc.abstractmethod def host(self): return None @property @abc.abstractmethod def database(self): return None @property @abc.abstractmethod def user(self): return None @property @abc.abstractmethod def password(self): return None @property def update_id(self): """ This update id will be a unique identifier for this insert on this table. """ return self.task_id def output(self): """ Returns a RedshiftTarget representing the inserted dataset. Normally you don't override this. """ # uses class name as a meta-table return RedshiftTarget( host=self.host, database=self.database, user=self.user, password=self.password, table=self.__class__.__name__, update_id=self.update_id ) def run(self): """ Kill any open Redshift sessions for the given database. """ connection = self.output().connect() # kill any sessions other than ours and # internal Redshift sessions (rdsdb) query = "select pg_terminate_backend(process) from STV_SESSIONS where db_name=%s and user_name != 'rdsdb' and process != pg_backend_pid()" cursor = connection.cursor() logger.info("Killing all open Redshift sessions for database: %s", self.database) try: cursor.execute(query, (self.database,)) cursor.close() connection.commit() except psycopg2.DatabaseError as e: if e.message and "EOF" in e.message: # sometimes this operation kills the current session. # rebuild the connection. Need to pause for 30-60 seconds # before Redshift will allow us back in. 
connection.close() logger.info("Pausing %s seconds for Redshift to reset connection", self.connection_reset_wait_seconds) time.sleep(self.connection_reset_wait_seconds) logger.info("Reconnecting to Redshift") connection = self.output().connect() else: raise try: self.output().touch(connection) connection.commit() finally: connection.close() logger.info("Done killing all open Redshift sessions for database: %s", self.database) class RedshiftQuery(postgres.PostgresQuery): """ Template task for querying an Amazon Redshift database Usage: Subclass and override the required `host`, `database`, `user`, `password`, `table`, and `query` attributes. Override the `run` method if your use case requires some action with the query result. Task instances require a dynamic `update_id`, e.g. via parameter(s), otherwise the query will only execute once To customize the query signature as recorded in the database marker table, override the `update_id` property. """ def output(self): """ Returns a RedshiftTarget representing the executed query. Normally you don't override this. """ return RedshiftTarget(host=self.host, database=self.database, user=self.user, password=self.password, table=self.table, update_id=self.update_id) class RedshiftUnloadTask(postgres.PostgresQuery, _CredentialsMixin): """ Template task for running UNLOAD on an Amazon Redshift database Usage: Subclass and override the required `host`, `database`, `user`, `password`, `table`, and `query` attributes. Optionally, override the `autocommit` attribute to run the query in autocommit mode - this is necessary to run VACUUM for example. Override the `run` method if your use case requires some action with the query result. Task instances require a dynamic `update_id`, e.g. via parameter(s), otherwise the query will only execute once To customize the query signature as recorded in the database marker table, override the `update_id` property. 
You can also override the attributes provided by the CredentialsMixin if they are not supplied by your configuration or environment variables. """ @property def s3_unload_path(self): """ Override to return the load path. """ return "" @property def unload_options(self): """ Add extra or override default unload options: """ return "DELIMITER '|' ADDQUOTES GZIP ALLOWOVERWRITE PARALLEL ON" @property def unload_query(self): """ Default UNLOAD command """ return "UNLOAD ( '{query}' ) TO '{s3_unload_path}' credentials '{credentials}' {unload_options};" def run(self): connection = self.output().connect() cursor = connection.cursor() unload_query = self.unload_query.format( query=self.query().replace("'", r"\'"), s3_unload_path=self.s3_unload_path, unload_options=self.unload_options, credentials=self._credentials() ) logger.info("Executing unload query from task: {name}".format(name=self.__class__)) cursor = connection.cursor() cursor.execute(unload_query) logger.info(cursor.statusmessage) # Update marker table self.output().touch(connection) # commit and close connection connection.commit() connection.close() def output(self): """ Returns a RedshiftTarget representing the executed query. Normally you don't override this. """ return RedshiftTarget(host=self.host, database=self.database, user=self.user, password=self.password, table=self.table, update_id=self.update_id) ================================================ FILE: luigi/contrib/s3.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Implementation of Simple Storage Service support. :py:class:`S3Target` is a subclass of the Target class to support S3 file system operations. The `boto3` library is required to use S3 targets. """ import datetime import itertools import logging import os import os.path import warnings from configparser import NoSectionError from multiprocessing.pool import ThreadPool from urllib.parse import urlsplit from luigi import configuration from luigi.format import get_default_format from luigi.parameter import OptionalParameter, Parameter from luigi.target import AtomicLocalFile, FileAlreadyExists, FileSystem, FileSystemException, FileSystemTarget, MissingParentDirectory from luigi.task import ExternalTask logger = logging.getLogger("luigi-interface") try: import botocore from boto3.s3.transfer import TransferConfig except ImportError: logger.warning("Loading S3 module without the python package boto3. Will crash at runtime if S3 functionality is used.") # two different ways of marking a directory # with a suffix in S3 S3_DIRECTORY_MARKER_SUFFIX_0 = "_$folder$" S3_DIRECTORY_MARKER_SUFFIX_1 = "/" class InvalidDeleteException(FileSystemException): pass class FileNotFoundException(FileSystemException): pass class DeprecatedBotoClientException(Exception): pass class S3Client(FileSystem): """ boto3-powered S3 client. 
""" _s3 = None DEFAULT_PART_SIZE = 8388608 DEFAULT_THREADS = 100 def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs): options = self._get_s3_config() options.update(kwargs) if aws_access_key_id: options["aws_access_key_id"] = aws_access_key_id if aws_secret_access_key: options["aws_secret_access_key"] = aws_secret_access_key if aws_session_token: options["aws_session_token"] = aws_session_token self._options = options @property def s3(self): # only import boto3 when needed to allow top-lvl s3 module import import boto3 options = dict(self._options) if self._s3: return self._s3 aws_access_key_id = options.get("aws_access_key_id") aws_secret_access_key = options.get("aws_secret_access_key") # Removing key args would break backwards compatibility role_arn = options.get("aws_role_arn") role_session_name = options.get("aws_role_session_name") # In case the aws_session_token is provided use it aws_session_token = options.get("aws_session_token") if role_arn and role_session_name: sts_client = boto3.client("sts") assumed_role = sts_client.assume_role(RoleArn=role_arn, RoleSessionName=role_session_name) aws_secret_access_key = assumed_role["Credentials"].get("SecretAccessKey") aws_access_key_id = assumed_role["Credentials"].get("AccessKeyId") aws_session_token = assumed_role["Credentials"].get("SessionToken") logger.debug("using aws credentials via assumed role {} as defined in luigi config".format(role_session_name)) for key in ["aws_access_key_id", "aws_secret_access_key", "aws_role_session_name", "aws_role_arn", "aws_session_token"]: if key in options: options.pop(key) # At this stage, if no credentials provided, boto3 would handle their resolution for us # For finding out about the order in which it tries to find these credentials # please see here details # http://boto3.readthedocs.io/en/latest/guide/configuration.html#configuring-credentials if not (aws_access_key_id and aws_secret_access_key): logger.debug("no 
credentials provided, delegating credentials resolution to boto3") try: self._s3 = boto3.resource( "s3", aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, **options ) except TypeError as e: logger.error(e.args[0]) if "got an unexpected keyword argument" in e.args[0]: raise DeprecatedBotoClientException("Now using boto3. Check that you're passing the correct arguments") raise return self._s3 @s3.setter def s3(self, value): self._s3 = value def exists(self, path): """ Does provided path exist on S3? """ (bucket, key) = self._path_to_bucket_and_key(path) # root always exists if self._is_root(key): return True # file if self._exists(bucket, key): return True if self.isdir(path): return True logger.debug("Path %s does not exist", path) return False def remove(self, path, recursive=True): """ Remove a file or directory from S3. :param path: File or directory to remove :param recursive: Boolean indicator to remove object and children :return: Boolean indicator denoting success of the removal of 1 or more files """ if not self.exists(path): logger.debug("Could not delete %s; path does not exist", path) return False (bucket, key) = self._path_to_bucket_and_key(path) s3_bucket = self.s3.Bucket(bucket) # root if self._is_root(key): raise InvalidDeleteException("Cannot delete root of bucket at path %s" % path) # file if self._exists(bucket, key): self.s3.meta.client.delete_object(Bucket=bucket, Key=key) logger.debug("Deleting %s from bucket %s", key, bucket) return True if self.isdir(path) and not recursive: raise InvalidDeleteException("Path %s is a directory. 
Must use recursive delete" % path) delete_key_list = [{"Key": obj.key} for obj in s3_bucket.objects.filter(Prefix=self._add_path_delimiter(key))] # delete the directory marker file if it exists if self._exists(bucket, "{}{}".format(key, S3_DIRECTORY_MARKER_SUFFIX_0)): delete_key_list.append({"Key": "{}{}".format(key, S3_DIRECTORY_MARKER_SUFFIX_0)}) if len(delete_key_list) > 0: n = 1000 for i in range(0, len(delete_key_list), n): self.s3.meta.client.delete_objects(Bucket=bucket, Delete={"Objects": delete_key_list[i : i + n]}) return True return False def move(self, source_path, destination_path, **kwargs): """ Rename/move an object from one S3 location to another. :param source_path: The `s3://` path of the directory or key to copy from :param destination_path: The `s3://` path of the directory or key to copy to :param kwargs: Keyword arguments are passed to the boto3 function `copy` """ self.copy(source_path, destination_path, **kwargs) self.remove(source_path) def get_key(self, path): """ Returns the object summary at the path """ (bucket, key) = self._path_to_bucket_and_key(path) if self._exists(bucket, key): return self.s3.ObjectSummary(bucket, key) def put(self, local_path, destination_s3_path, **kwargs): """ Put an object stored locally to an S3 path. :param local_path: Path to source local file :param destination_s3_path: URL for target S3 location :param kwargs: Keyword arguments are passed to the boto function `put_object` """ self._check_deprecated_argument(**kwargs) # put the file self.put_multipart(local_path, destination_s3_path, **kwargs) def put_string(self, content, destination_s3_path, **kwargs): """ Put a string to an S3 path. 
:param content: Data str :param destination_s3_path: URL for target S3 location :param kwargs: Keyword arguments are passed to the boto3 function `put_object` """ self._check_deprecated_argument(**kwargs) (bucket, key) = self._path_to_bucket_and_key(destination_s3_path) # put the file self.s3.meta.client.put_object(Key=key, Bucket=bucket, Body=content, **kwargs) def put_multipart(self, local_path, destination_s3_path, part_size=DEFAULT_PART_SIZE, **kwargs): """ Put an object stored locally to an S3 path using S3 multi-part upload (for files > 8Mb). :param local_path: Path to source local file :param destination_s3_path: URL for target S3 location :param part_size: Part size in bytes. Default: 8388608 (8MB) :param kwargs: Keyword arguments are passed to the boto function `upload_fileobj` as ExtraArgs """ self._check_deprecated_argument(**kwargs) from boto3.s3.transfer import TransferConfig # default part size for boto3 is 8Mb, changing it to fit part_size # provided as a parameter transfer_config = TransferConfig(multipart_chunksize=part_size) (bucket, key) = self._path_to_bucket_and_key(destination_s3_path) self.s3.meta.client.upload_fileobj(Fileobj=open(local_path, "rb"), Bucket=bucket, Key=key, Config=transfer_config, ExtraArgs=kwargs) def copy(self, source_path, destination_path, threads=DEFAULT_THREADS, start_time=None, end_time=None, part_size=DEFAULT_PART_SIZE, **kwargs): """ Copy object(s) from one S3 location to another. Works for individual keys or entire directories. When files are larger than `part_size`, multipart uploading will be used. 
:param source_path: The `s3://` path of the directory or key to copy from :param destination_path: The `s3://` path of the directory or key to copy to :param threads: Optional argument to define the number of threads to use when copying (min: 3 threads) :param start_time: Optional argument to copy files with modified dates after start_time :param end_time: Optional argument to copy files with modified dates before end_time :param part_size: Part size in bytes :param kwargs: Keyword arguments are passed to the boto function `copy` as ExtraArgs :returns tuple (number_of_files_copied, total_size_copied_in_bytes) """ # don't allow threads to be less than 3 threads = 3 if threads < 3 else threads if self.isdir(source_path): return self._copy_dir(source_path, destination_path, threads=threads, start_time=start_time, end_time=end_time, part_size=part_size, **kwargs) # If the file isn't a directory just perform a simple copy else: return self._copy_file(source_path, destination_path, threads=threads, part_size=part_size, **kwargs) def _copy_file(self, source_path, destination_path, threads=DEFAULT_THREADS, part_size=DEFAULT_PART_SIZE, **kwargs): src_bucket, src_key = self._path_to_bucket_and_key(source_path) dst_bucket, dst_key = self._path_to_bucket_and_key(destination_path) transfer_config = TransferConfig(max_concurrency=threads, multipart_chunksize=part_size) item = self.get_key(source_path) copy_source = {"Bucket": src_bucket, "Key": src_key} self.s3.meta.client.copy(copy_source, dst_bucket, dst_key, Config=transfer_config, ExtraArgs=kwargs) return 1, item.size def _copy_dir(self, source_path, destination_path, threads=DEFAULT_THREADS, start_time=None, end_time=None, part_size=DEFAULT_PART_SIZE, **kwargs): start = datetime.datetime.now() copy_jobs = [] management_pool = ThreadPool(processes=threads) transfer_config = TransferConfig(max_concurrency=threads, multipart_chunksize=part_size) src_bucket, src_key = self._path_to_bucket_and_key(source_path) dst_bucket, 
dst_key = self._path_to_bucket_and_key(destination_path) src_prefix = self._add_path_delimiter(src_key) dst_prefix = self._add_path_delimiter(dst_key) key_path_len = len(src_prefix) total_size_bytes = 0 total_keys = 0 for item in self.list(source_path, start_time=start_time, end_time=end_time, return_key=True): path = item.key[key_path_len:] # prevents copy attempt of empty key in folder if path != "" and path != "/": total_keys += 1 total_size_bytes += item.size copy_source = {"Bucket": src_bucket, "Key": src_prefix + path} the_kwargs = {"Config": transfer_config, "ExtraArgs": kwargs} job = management_pool.apply_async(self.s3.meta.client.copy, args=(copy_source, dst_bucket, dst_prefix + path), kwds=the_kwargs) copy_jobs.append(job) # Wait for the pools to finish scheduling all the copies management_pool.close() management_pool.join() # Raise any errors encountered in any of the copy processes for result in copy_jobs: result.get() end = datetime.datetime.now() duration = end - start logger.info("%s : Complete : %s total keys copied in %s" % (datetime.datetime.now(), total_keys, duration)) return total_keys, total_size_bytes def get(self, s3_path, destination_local_path): """ Get an object stored in S3 and write it to a local path. """ (bucket, key) = self._path_to_bucket_and_key(s3_path) # download the file self.s3.meta.client.download_file(bucket, key, destination_local_path) def get_as_bytes(self, s3_path): """ Get the contents of an object stored in S3 as bytes :param s3_path: URL for target S3 location :return: File contents as pure bytes """ (bucket, key) = self._path_to_bucket_and_key(s3_path) obj = self.s3.Object(bucket, key) contents = obj.get()["Body"].read() return contents def get_as_string(self, s3_path, encoding="utf-8"): """ Get the contents of an object stored in S3 as string. 
:param s3_path: URL for target S3 location :param encoding: Encoding to decode bytes to string :return: File contents as a string """ content = self.get_as_bytes(s3_path) return content.decode(encoding) def isdir(self, path): """ Is the parameter S3 path a directory? """ (bucket, key) = self._path_to_bucket_and_key(path) s3_bucket = self.s3.Bucket(bucket) # root is a directory if self._is_root(key): return True for suffix in (S3_DIRECTORY_MARKER_SUFFIX_0, S3_DIRECTORY_MARKER_SUFFIX_1): try: self.s3.meta.client.get_object(Bucket=bucket, Key=key + suffix) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] not in ["NoSuchKey", "404"]: raise else: return True # files with this prefix key_path = self._add_path_delimiter(key) s3_bucket_list_result = list(itertools.islice(s3_bucket.objects.filter(Prefix=key_path), 1)) if s3_bucket_list_result: return True return False is_dir = isdir # compatibility with old version. def mkdir(self, path, parents=True, raise_if_exists=False): if raise_if_exists and self.isdir(path): raise FileAlreadyExists() bucket, key = self._path_to_bucket_and_key(path) if self._is_root(key): # isdir raises if the bucket doesn't exist; nothing to do here. return path = self._add_path_delimiter(path) if not parents and not self.isdir(os.path.dirname(path)): raise MissingParentDirectory() return self.put_string("", path) def listdir(self, path, start_time=None, end_time=None, return_key=False): """ Get an iterable with S3 folder contents. Iterable contains absolute paths for which queried path is a prefix. 
:param path: URL for target S3 location :param start_time: Optional argument to list files with modified (offset aware) datetime after start_time :param end_time: Optional argument to list files with modified (offset aware) datetime before end_time :param return_key: Optional argument, when set to True will return boto3's ObjectSummary (instead of the filename) """ (bucket, key) = self._path_to_bucket_and_key(path) # grab and validate the bucket s3_bucket = self.s3.Bucket(bucket) key_path = self._add_path_delimiter(key) key_path_len = len(key_path) for item in s3_bucket.objects.filter(Prefix=key_path): last_modified_date = item.last_modified if ( # neither are defined, list all (not start_time and not end_time) or # start defined, after start (start_time and not end_time and start_time < last_modified_date) or # end defined, prior to end (not start_time and end_time and last_modified_date < end_time) or (start_time and end_time and start_time < last_modified_date < end_time) # both defined, between ): if return_key: yield item else: yield self._add_path_delimiter(path) + item.key[key_path_len:] def list(self, path, start_time=None, end_time=None, return_key=False): # backwards compat """ Get an iterable with S3 folder contents. Iterable contains paths relative to queried path. 
:param path: URL for target S3 location :param start_time: Optional argument to list files with modified (offset aware) datetime after start_time :param end_time: Optional argument to list files with modified (offset aware) datetime before end_time :param return_key: Optional argument, when set to True will return boto3's ObjectSummary (instead of the filename) """ key_path_len = len(self._add_path_delimiter(path)) for item in self.listdir(path, start_time=start_time, end_time=end_time, return_key=return_key): if return_key: yield item else: yield item[key_path_len:] @staticmethod def _get_s3_config(key=None): defaults = dict(configuration.get_config().defaults()) try: config = dict(configuration.get_config().items("s3")) except (NoSectionError, KeyError): return {} # So what ports etc can be read without us having to specify all dtypes for k, v in config.items(): try: config[k] = int(v) except ValueError: pass if key: return config.get(key) section_only = {k: v for k, v in config.items() if k not in defaults or v != defaults[k]} return section_only @staticmethod def _path_to_bucket_and_key(path): (scheme, netloc, path, query, fragment) = urlsplit(path, allow_fragments=False) question_mark_plus_query = "?" + query if query else "" path_without_initial_slash = path[1:] + question_mark_plus_query return netloc, path_without_initial_slash @staticmethod def _is_root(key): return (len(key) == 0) or (key == "/") @staticmethod def _add_path_delimiter(key): return key if key[-1:] == "/" or key == "" else key + "/" @staticmethod def _check_deprecated_argument(**kwargs): """ If `encrypt_key` or `host` is part of the arguments raise an exception :return: None """ if "encrypt_key" in kwargs: raise DeprecatedBotoClientException("encrypt_key deprecated in boto3. 
Please refer to boto3 documentation for encryption details.") if "host" in kwargs: raise DeprecatedBotoClientException( "host keyword deprecated and is replaced by region_name in boto3.\n" "example: region_name=us-west-1\n" "For region names, refer to the amazon S3 region documentation\n" "https://docs.aws.amazon.com/general/latest/gr/rande.html#s3_region" ) def _exists(self, bucket, key): try: self.s3.Object(bucket, key).load() except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] in ["NoSuchKey", "404"]: return False else: raise return True class AtomicS3File(AtomicLocalFile): """ An S3 file that writes to a temp file and puts to S3 on close. :param kwargs: Keyword arguments are passed to the boto function `initiate_multipart_upload` """ def __init__(self, path, s3_client, **kwargs): self.s3_client = s3_client super(AtomicS3File, self).__init__(path) self.s3_options = kwargs def move_to_final_destination(self): self.s3_client.put_multipart(self.tmp_path, self.path, **self.s3_options) class ReadableS3File: def __init__(self, s3_key): self.s3_key = s3_key.get()["Body"] self.buffer = [] self.closed = False self.finished = False def read(self, size=None): f = self.s3_key.read(size) return f def close(self): self.s3_key.close() self.closed = True def __del__(self): self.close() def __exit__(self, exc_type, exc, traceback): self.close() def __enter__(self): return self def _add_to_buffer(self, line): self.buffer.append(line) def _flush_buffer(self): output = b"".join(self.buffer) self.buffer = [] return output def readable(self): return True def writable(self): return False def seekable(self): return False def __iter__(self): key_iter = self.s3_key.__iter__() has_next = True while has_next: try: # grab the next chunk chunk = next(key_iter) # split on newlines, preserving the newline for line in chunk.splitlines(True): if not line.endswith(os.linesep): # no newline, so store in buffer self._add_to_buffer(line) else: # newline found, send it out if 
self.buffer: self._add_to_buffer(line) yield self._flush_buffer() else: yield line except StopIteration: # send out anything we have left in the buffer output = self._flush_buffer() if output: yield output has_next = False self.close() class S3Target(FileSystemTarget): """ Target S3 file object :param kwargs: Keyword arguments are passed to the boto function `initiate_multipart_upload` """ fs = None def __init__(self, path, format=None, client=None, **kwargs): super(S3Target, self).__init__(path) if format is None: format = get_default_format() self.path = path self.format = format self.fs = client or S3Client() self.s3_options = kwargs def open(self, mode="r"): if mode not in ("r", "w"): raise ValueError("Unsupported open mode '%s'" % mode) if mode == "r": s3_key = self.fs.get_key(self.path) if not s3_key: raise FileNotFoundException("Could not find file at %s" % self.path) fileobj = ReadableS3File(s3_key) return self.format.pipe_reader(fileobj) else: return self.format.pipe_writer(AtomicS3File(self.path, self.fs, **self.s3_options)) class S3FlagTarget(S3Target): """ Defines a target directory with a flag-file (defaults to `_SUCCESS`) used to signify job success. This checks for two things: * the path exists (just like the S3Target) * the _SUCCESS file exists within the directory. Because Hadoop outputs into a directory and not a single file, the path is assumed to be a directory. This is meant to be a handy alternative to AtomicS3File. The AtomicFile approach can be burdensome for S3 since there are no directories, per se. If we have 1,000,000 output files, then we have to rename 1,000,000 objects. """ fs = None def __init__(self, path, format=None, client=None, flag="_SUCCESS"): """ Initializes a S3FlagTarget. :param path: the directory where the files are stored. 
:type path: str :param client: :type client: :param flag: :type flag: str """ if format is None: format = get_default_format() if path[-1] != "/": raise ValueError("S3FlagTarget requires the path to be to a directory. It must end with a slash ( / ).") super(S3FlagTarget, self).__init__(path, format, client) self.flag = flag def exists(self): hadoopSemaphore = self.path + self.flag return self.fs.exists(hadoopSemaphore) class S3EmrTarget(S3FlagTarget): """ Deprecated. Use :py:class:`S3FlagTarget` """ def __init__(self, *args, **kwargs): warnings.warn("S3EmrTarget is deprecated. Please use S3FlagTarget") super(S3EmrTarget, self).__init__(*args, **kwargs) class S3PathTask(ExternalTask): """ A external task that to require existence of a path in S3. """ path = Parameter() def output(self): return S3Target(self.path) class S3EmrTask(ExternalTask): """ An external task that requires the existence of EMR output in S3. """ path = Parameter() def output(self): return S3EmrTarget(self.path) class S3FlagTask(ExternalTask): """ An external task that requires the existence of EMR output in S3. """ path = Parameter() flag = OptionalParameter(default=None) def output(self): return S3FlagTarget(self.path, flag=self.flag) ================================================ FILE: luigi/contrib/salesforce.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import abc import csv import logging import re import tempfile import time import warnings import xml.etree.ElementTree as ET from collections import OrderedDict from urllib.parse import urlsplit import luigi from luigi import Task logger = logging.getLogger("luigi-interface") try: import requests except ImportError: logger.warning("This module requires the python package 'requests'.") def get_soql_fields(soql): """ Gets queried columns names. """ soql_fields = re.search("(?<=select)(?s)(.*)(?=from)", soql, re.IGNORECASE) # get fields soql_fields = re.sub(" ", "", soql_fields.group()) # remove extra spaces soql_fields = re.sub("\t", "", soql_fields) # remove tabs fields = re.split(",|\n|\r|", soql_fields) # split on commas and newlines fields = [field for field in fields if field != ""] # remove empty strings return fields def ensure_utf(value): return value.encode("utf-8") if isinstance(value, str) else value def parse_results(fields, data): """ Traverses ordered dictionary, calls _traverse_results() to recursively read into the dictionary depth of data """ master = [] for record in data["records"]: # for each 'record' in response row = [None] * len(fields) # create null list the length of number of columns for obj, value in record.items(): # for each obj in record if not isinstance(value, (dict, list, tuple)): # if not data structure if obj in fields: row[fields.index(obj)] = ensure_utf(value) elif isinstance(value, dict) and obj != "attributes": # traverse down into object path = obj _traverse_results(value, fields, row, path) master.append(row) return master def _traverse_results(value, fields, row, path): """ Helper method for parse_results(). 
Traverses through ordered dict and recursively calls itself when encountering a dictionary """ for f, v in value.items(): # for each item in obj field_name = "{path}.{name}".format(path=path, name=f) if path else f if not isinstance(v, (dict, list, tuple)): # if not data structure if field_name in fields: row[fields.index(field_name)] = ensure_utf(v) elif isinstance(v, dict) and f != "attributes": # it is a dict _traverse_results(v, fields, row, field_name) class salesforce(luigi.Config): """ Config system to get config vars from 'salesforce' section in configuration file. Did not include sandbox_name here, as the user may have multiple sandboxes. """ username = luigi.Parameter(default="") password = luigi.Parameter(default="") security_token = luigi.Parameter(default="") # sandbox token sb_security_token = luigi.Parameter(default="") class QuerySalesforce(Task): @property @abc.abstractmethod def object_name(self): """ Override to return the SF object we are querying. Must have the SF "__c" suffix if it is a customer object. """ return None @property def use_sandbox(self): """ Override to specify use of SF sandbox. True iff we should be uploading to a sandbox environment instead of the production organization. """ return False @property def sandbox_name(self): """Override to specify the sandbox name if it is intended to be used.""" return None @property @abc.abstractmethod def soql(self): """Override to return the raw string SOQL or the path to it.""" return None @property def is_soql_file(self): """Override to True if soql property is a file path.""" return False @property def content_type(self): """ Override to use a different content type. Salesforce allows XML, CSV, ZIP_CSV, or ZIP_XML. Defaults to CSV. 
""" return "CSV" def run(self): if self.use_sandbox and not self.sandbox_name: raise Exception("Parameter sf_sandbox_name must be provided when uploading to a Salesforce Sandbox") sf = SalesforceAPI(salesforce().username, salesforce().password, salesforce().security_token, salesforce().sb_security_token, self.sandbox_name) job_id = sf.create_operation_job("query", self.object_name, content_type=self.content_type) logger.info("Started query job %s in salesforce for object %s" % (job_id, self.object_name)) batch_id = "" msg = "" try: if self.is_soql_file: with open(self.soql, "r") as infile: self.soql = infile.read() batch_id = sf.create_batch(job_id, self.soql, self.content_type) logger.info("Creating new batch %s to query: %s for job: %s." % (batch_id, self.object_name, job_id)) status = sf.block_on_batch(job_id, batch_id) if status["state"].lower() == "failed": msg = "Batch failed with message: %s" % status["state_message"] logger.error(msg) # don't raise exception if it's b/c of an included relationship # normal query will execute (with relationship) after bulk job is closed if "foreign key relationships not supported" not in status["state_message"].lower(): raise Exception(msg) else: result_ids = sf.get_batch_result_ids(job_id, batch_id) # If there's only one result, just download it, otherwise we need to merge the resulting downloads if len(result_ids) == 1: data = sf.get_batch_result(job_id, batch_id, result_ids[0]) with open(self.output().path, "wb") as outfile: outfile.write(data) else: # Download each file to disk, and then merge into one. # Preferring to do it this way so as to minimize memory consumption. 
for i, result_id in enumerate(result_ids): logger.info("Downloading batch result %s for batch: %s and job: %s" % (result_id, batch_id, job_id)) with open("%s.%d" % (self.output().path, i), "wb") as outfile: outfile.write(sf.get_batch_result(job_id, batch_id, result_id)) logger.info("Merging results of batch %s" % batch_id) self.merge_batch_results(result_ids) finally: logger.info("Closing job %s" % job_id) sf.close_job(job_id) if "state_message" in status and "foreign key relationships not supported" in status["state_message"].lower(): logger.info("Retrying with REST API query") data_file = sf.query_all(self.soql) reader = csv.reader(data_file) with open(self.output().path, "wb") as outfile: writer = csv.writer(outfile, dialect="excel") for row in reader: writer.writerow(row) def merge_batch_results(self, result_ids): """ Merges the resulting files of a multi-result batch bulk query. """ outfile = open(self.output().path, "w") if self.content_type.lower() == "csv": for i, result_id in enumerate(result_ids): with open("%s.%d" % (self.output().path, i), "r") as f: header = f.readline() if i == 0: outfile.write(header) for line in f: outfile.write(line) else: raise Exception("Batch result merging not implemented for %s" % self.content_type) outfile.close() class SalesforceAPI: """ Class used to interact with the SalesforceAPI. Currently provides only the methods necessary for performing a bulk upload operation. 
""" API_VERSION = 34.0 SOAP_NS = "{urn:partner.soap.sforce.com}" API_NS = "{http://www.force.com/2009/06/asyncapi/dataload}" def __init__(self, username, password, security_token, sb_token=None, sandbox_name=None): self.username = username self.password = password self.security_token = security_token self.sb_security_token = sb_token self.sandbox_name = sandbox_name if self.sandbox_name: self.username += ".%s" % self.sandbox_name self.session_id = None self.server_url = None self.hostname = None def start_session(self): """ Starts a Salesforce session and determines which SF instance to use for future requests. """ if self.has_active_session(): raise Exception("Session already in progress.") response = requests.post(self._get_login_url(), headers=self._get_login_headers(), data=self._get_login_xml()) response.raise_for_status() root = ET.fromstring(response.text) for e in root.iter("%ssessionId" % self.SOAP_NS): if self.session_id: raise Exception("Invalid login attempt. Multiple session ids found.") self.session_id = e.text for e in root.iter("%sserverUrl" % self.SOAP_NS): if self.server_url: raise Exception("Invalid login attempt. Multiple server urls found.") self.server_url = e.text if not self.has_active_session(): raise Exception("Invalid login attempt resulted in null sessionId [%s] and/or serverUrl [%s]." % (self.session_id, self.server_url)) self.hostname = urlsplit(self.server_url).hostname def has_active_session(self): return self.session_id and self.server_url def query(self, query, **kwargs): """ Return the result of a Salesforce SOQL query as a dict decoded from the Salesforce response JSON payload. :param query: the SOQL query to send to Salesforce, e.g. 
"SELECT id from Lead WHERE email = 'a@b.com'" """ params = {"q": query} response = requests.get(self._get_norm_query_url(), headers=self._get_rest_headers(), params=params, **kwargs) if response.status_code != requests.codes.ok: raise Exception(response.content) return response.json() def query_more(self, next_records_identifier, identifier_is_url=False, **kwargs): """ Retrieves more results from a query that returned more results than the batch maximum. Returns a dict decoded from the Salesforce response JSON payload. :param next_records_identifier: either the Id of the next Salesforce object in the result, or a URL to the next record in the result. :param identifier_is_url: True if `next_records_identifier` should be treated as a URL, False if `next_records_identifer` should be treated as an Id. """ if identifier_is_url: # Don't use `self.base_url` here because the full URI is provided url = "https://{instance}{next_record_url}".format(instance=self.hostname, next_record_url=next_records_identifier) else: url = self._get_norm_query_url() + "{next_record_id}" url = url.format(next_record_id=next_records_identifier) response = requests.get(url, headers=self._get_rest_headers(), **kwargs) response.raise_for_status() return response.json() def query_all(self, query, **kwargs): """ Returns the full set of results for the `query`. This is a convenience wrapper around `query(...)` and `query_more(...)`. The returned dict is the decoded JSON payload from the final call to Salesforce, but with the `totalSize` field representing the full number of results retrieved and the `records` list representing the full list of records retrieved. :param query: the SOQL query to send to Salesforce, e.g. 
`SELECT Id FROM Lead WHERE Email = "waldo@somewhere.com"` """ # Make the initial query to Salesforce response = self.query(query, **kwargs) # get fields fields = get_soql_fields(query) # put fields and first page of results into a temp list to be written to TempFile tmp_list = [fields] tmp_list.extend(parse_results(fields, response)) tmp_dir = luigi.configuration.get_config().get("salesforce", "local-tmp-dir", None) tmp_file = tempfile.TemporaryFile(mode="a+b", dir=tmp_dir) writer = csv.writer(tmp_file) writer.writerows(tmp_list) # The number of results might have exceeded the Salesforce batch limit # so check whether there are more results and retrieve them if so. length = len(response["records"]) while not response["done"]: response = self.query_more(response["nextRecordsUrl"], identifier_is_url=True, **kwargs) writer.writerows(parse_results(fields, response)) length += len(response["records"]) if not length % 10000: logger.info("Requested {0} lines...".format(length)) logger.info("Requested a total of {0} lines.".format(length)) tmp_file.seek(0) return tmp_file # Generic Rest Function def restful(self, path, params): """ Allows you to make a direct REST call if you know the path Arguments: :param path: The path of the request. Example: sobjects/User/ABC123/password' :param params: dict of parameters to pass to the path """ url = self._get_norm_base_url() + path response = requests.get(url, headers=self._get_rest_headers(), params=params) if response.status_code != 200: raise Exception(response) json_result = response.json(object_pairs_hook=OrderedDict) if len(json_result) == 0: return None else: return json_result def create_operation_job(self, operation, obj, external_id_field_name=None, content_type=None): """ Creates a new SF job that for doing any operation (insert, upsert, update, delete, query) :param operation: delete, insert, query, upsert, update, hardDelete. Must be lowercase. :param obj: Parent SF object :param external_id_field_name: Optional. 
""" if not self.has_active_session(): self.start_session() response = requests.post( self._get_create_job_url(), headers=self._get_create_job_headers(), data=self._get_create_job_xml(operation, obj, external_id_field_name, content_type), ) response.raise_for_status() root = ET.fromstring(response.text) job_id = root.find("%sid" % self.API_NS).text return job_id def get_job_details(self, job_id): """ Gets all details for existing job :param job_id: job_id as returned by 'create_operation_job(...)' :return: job info as xml """ response = requests.get(self._get_job_details_url(job_id)) response.raise_for_status() return response def abort_job(self, job_id): """ Abort an existing job. When a job is aborted, no more records are processed. Changes to data may already have been committed and aren't rolled back. :param job_id: job_id as returned by 'create_operation_job(...)' :return: abort response as xml """ response = requests.post(self._get_abort_job_url(job_id), headers=self._get_abort_job_headers(), data=self._get_abort_job_xml()) response.raise_for_status() return response def close_job(self, job_id): """ Closes job :param job_id: job_id as returned by 'create_operation_job(...)' :return: close response as xml """ if not job_id or not self.has_active_session(): raise Exception("Can not close job without valid job_id and an active session.") response = requests.post(self._get_close_job_url(job_id), headers=self._get_close_job_headers(), data=self._get_close_job_xml()) response.raise_for_status() return response def create_batch(self, job_id, data, file_type): """ Creates a batch with either a string of data or a file containing data. If a file is provided, this will pull the contents of the file_target into memory when running. That shouldn't be a problem for any files that meet the Salesforce single batch upload size limit (10MB) and is done to ensure compressed files can be uploaded properly. 
:param job_id: job_id as returned by 'create_operation_job(...)' :param data: :return: Returns batch_id """ if not job_id or not self.has_active_session(): raise Exception("Can not create a batch without a valid job_id and an active session.") headers = self._get_create_batch_content_headers(file_type) headers["Content-Length"] = str(len(data)) response = requests.post(self._get_create_batch_url(job_id), headers=headers, data=data) response.raise_for_status() root = ET.fromstring(response.text) batch_id = root.find("%sid" % self.API_NS).text return batch_id def block_on_batch(self, job_id, batch_id, sleep_time_seconds=5, max_wait_time_seconds=-1): """ Blocks until @batch_id is completed or failed. :param job_id: :param batch_id: :param sleep_time_seconds: :param max_wait_time_seconds: """ if not job_id or not batch_id or not self.has_active_session(): raise Exception("Can not block on a batch without a valid batch_id, job_id and an active session.") start_time = time.time() status = {} while max_wait_time_seconds < 0 or time.time() - start_time < max_wait_time_seconds: status = self._get_batch_info(job_id, batch_id) logger.info( "Batch %s Job %s in state %s. %s records processed. %s records failed." % (batch_id, job_id, status["state"], status["num_processed"], status["num_failed"]) ) if status["state"].lower() in ["completed", "failed"]: return status time.sleep(sleep_time_seconds) raise Exception("Batch did not complete in %s seconds. Final status was: %s" % (sleep_time_seconds, status)) def get_batch_results(self, job_id, batch_id): """ DEPRECATED: Use `get_batch_result_ids` """ warnings.warn("get_batch_results is deprecated and only returns one batch result. Please use get_batch_result_ids") return self.get_batch_result_ids(job_id, batch_id)[0] def get_batch_result_ids(self, job_id, batch_id): """ Get result IDs of a batch that has completed processing. 
:param job_id: job_id as returned by 'create_operation_job(...)' :param batch_id: batch_id as returned by 'create_batch(...)' :return: list of batch result IDs to be used in 'get_batch_result(...)' """ response = requests.get(self._get_batch_results_url(job_id, batch_id), headers=self._get_batch_info_headers()) response.raise_for_status() root = ET.fromstring(response.text) result_ids = [r.text for r in root.findall("%sresult" % self.API_NS)] return result_ids def get_batch_result(self, job_id, batch_id, result_id): """ Gets result back from Salesforce as whatever type was originally sent in create_batch (xml, or csv). :param job_id: :param batch_id: :param result_id: """ response = requests.get(self._get_batch_result_url(job_id, batch_id, result_id), headers=self._get_session_headers()) response.raise_for_status() return response.content def _get_batch_info(self, job_id, batch_id): response = requests.get(self._get_batch_info_url(job_id, batch_id), headers=self._get_batch_info_headers()) response.raise_for_status() root = ET.fromstring(response.text) result = { "state": root.find("%sstate" % self.API_NS).text, "num_processed": root.find("%snumberRecordsProcessed" % self.API_NS).text, "num_failed": root.find("%snumberRecordsFailed" % self.API_NS).text, } if root.find("%sstateMessage" % self.API_NS) is not None: result["state_message"] = root.find("%sstateMessage" % self.API_NS).text return result def _get_login_url(self): server = "login" if not self.sandbox_name else "test" return "https://%s.salesforce.com/services/Soap/u/%s" % (server, self.API_VERSION) def _get_base_url(self): return "https://%s/services" % self.hostname def _get_bulk_base_url(self): # Expands on Base Url for Bulk return "%s/async/%s" % (self._get_base_url(), self.API_VERSION) def _get_norm_base_url(self): # Expands on Base Url for Norm return "%s/data/v%s" % (self._get_base_url(), self.API_VERSION) def _get_norm_query_url(self): # Expands on Norm Base Url return "%s/query" % 
self._get_norm_base_url() def _get_create_job_url(self): # Expands on Bulk url return "%s/job" % (self._get_bulk_base_url()) def _get_job_id_url(self, job_id): # Expands on Job Creation url return "%s/%s" % (self._get_create_job_url(), job_id) def _get_job_details_url(self, job_id): # Expands on basic Job Id url return self._get_job_id_url(job_id) def _get_abort_job_url(self, job_id): # Expands on basic Job Id url return self._get_job_id_url(job_id) def _get_close_job_url(self, job_id): # Expands on basic Job Id url return self._get_job_id_url(job_id) def _get_create_batch_url(self, job_id): # Expands on basic Job Id url return "%s/batch" % (self._get_job_id_url(job_id)) def _get_batch_info_url(self, job_id, batch_id): # Expands on Batch Creation url return "%s/%s" % (self._get_create_batch_url(job_id), batch_id) def _get_batch_results_url(self, job_id, batch_id): # Expands on Batch Info url return "%s/result" % (self._get_batch_info_url(job_id, batch_id)) def _get_batch_result_url(self, job_id, batch_id, result_id): # Expands on Batch Results url return "%s/%s" % (self._get_batch_results_url(job_id, batch_id), result_id) def _get_login_headers(self): headers = {"Content-Type": "text/xml; charset=UTF-8", "SOAPAction": "login"} return headers def _get_session_headers(self): headers = {"X-SFDC-Session": self.session_id} return headers def _get_norm_session_headers(self): headers = {"Authorization": "Bearer %s" % self.session_id} return headers def _get_rest_headers(self): headers = self._get_norm_session_headers() headers["Content-Type"] = "application/json" return headers def _get_job_headers(self): headers = self._get_session_headers() headers["Content-Type"] = "application/xml; charset=UTF-8" return headers def _get_create_job_headers(self): return self._get_job_headers() def _get_abort_job_headers(self): return self._get_job_headers() def _get_close_job_headers(self): return self._get_job_headers() def _get_create_batch_content_headers(self, content_type): 
headers = self._get_session_headers() content_type = "text/csv" if content_type.lower() == "csv" else "application/xml" headers["Content-Type"] = "%s; charset=UTF-8" % content_type return headers def _get_batch_info_headers(self): return self._get_session_headers() def _get_login_xml(self): return """ %s %s%s """ % (self.username, self.password, self.security_token if self.sandbox_name is None else self.sb_security_token) def _get_create_job_xml(self, operation, obj, external_id_field_name, content_type): external_id_field_name_element = "" if not external_id_field_name else "\n%s" % external_id_field_name # Note: "Unable to parse job" error may be caused by reordering fields. # ExternalIdFieldName element must be before contentType element. return """ %s %s %s %s """ % (operation, obj, external_id_field_name_element, content_type) def _get_abort_job_xml(self): return """ Aborted """ def _get_close_job_xml(self): return """ Closed """ ================================================ FILE: luigi/contrib/scalding.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import logging import os import re import subprocess import warnings import luigi.configuration import luigi.contrib.hadoop import luigi.contrib.hadoop_jar import luigi.contrib.hdfs from luigi import LocalTarget from luigi.task import flatten logger = logging.getLogger("luigi-interface") """ Scalding support for Luigi. Example configuration section in luigi.cfg:: [scalding] # scala home directory, which should include a lib subdir with scala jars. scala-home: /usr/share/scala # scalding home directory, which should include a lib subdir with # scalding-*-assembly-* jars as built from the official Twitter build script. scalding-home: /usr/share/scalding # provided dependencies, e.g. jars required for compiling but not executing # scalding jobs. Currently required jars: # org.apache.hadoop/hadoop-core/0.20.2 # org.slf4j/slf4j-log4j12/1.6.6 # log4j/log4j/1.2.15 # commons-httpclient/commons-httpclient/3.1 # commons-cli/commons-cli/1.2 # org.apache.zookeeper/zookeeper/3.3.4 scalding-provided: /usr/share/scalding/provided # additional jars required. scalding-libjars: /usr/share/scalding/libjars """ class ScaldingJobRunner(luigi.contrib.hadoop.JobRunner): """ JobRunner for `pyscald` commands. Used to run a ScaldingJobTask. 
""" def __init__(self): conf = luigi.configuration.get_config() default = os.environ.get("SCALA_HOME", "/usr/share/scala") self.scala_home = conf.get("scalding", "scala-home", default) default = os.environ.get("SCALDING_HOME", "/usr/share/scalding") self.scalding_home = conf.get("scalding", "scalding-home", default) self.provided_dir = conf.get("scalding", "scalding-provided", os.path.join(default, "provided")) self.libjars_dir = conf.get("scalding", "scalding-libjars", os.path.join(default, "libjars")) self.tmp_dir = LocalTarget(is_tmp=True) def _get_jars(self, path): return [os.path.join(path, j) for j in os.listdir(path) if j.endswith(".jar")] def get_scala_jars(self, include_compiler=False): lib_dir = os.path.join(self.scala_home, "lib") jars = [os.path.join(lib_dir, "scala-library.jar")] # additional jar for scala 2.10 only reflect = os.path.join(lib_dir, "scala-reflect.jar") if os.path.exists(reflect): jars.append(reflect) if include_compiler: jars.append(os.path.join(lib_dir, "scala-compiler.jar")) return jars def get_scalding_jars(self): lib_dir = os.path.join(self.scalding_home, "lib") return self._get_jars(lib_dir) def get_scalding_core(self): lib_dir = os.path.join(self.scalding_home, "lib") for j in os.listdir(lib_dir): if j.startswith("scalding-core-"): p = os.path.join(lib_dir, j) logger.debug("Found scalding-core: %s", p) return p raise luigi.contrib.hadoop.HadoopJobError("Could not find scalding-core.") def get_provided_jars(self): return self._get_jars(self.provided_dir) def get_libjars(self): return self._get_jars(self.libjars_dir) def get_tmp_job_jar(self, source): job_name = os.path.basename(os.path.splitext(source)[0]) return os.path.join(self.tmp_dir.path, job_name + ".jar") def get_build_dir(self, source): build_dir = os.path.join(self.tmp_dir.path, "build") return build_dir def get_job_class(self, source): # find name of the job class # usually the one that matches file name or last class that extends Job job_name = 
os.path.splitext(os.path.basename(source))[0] package = None job_class = None for line in open(source).readlines(): p = re.search(r"package\s+([^\s\(]+)", line) if p: package = p.groups()[0] p = re.search(r"class\s+([^\s\(]+).*extends\s+.*Job", line) if p: job_class = p.groups()[0] if job_class == job_name: break if job_class: if package: job_class = package + "." + job_class logger.debug("Found scalding job class: %s", job_class) return job_class else: raise luigi.contrib.hadoop.HadoopJobError("Coudl not find scalding job class.") def build_job_jar(self, job): job_jar = job.jar() if job_jar: if not os.path.exists(job_jar): logger.error("Can't find jar: %s, full path %s", job_jar, os.path.abspath(job_jar)) raise Exception("job jar does not exist") if not job.job_class(): logger.error("Undefined job_class()") raise Exception("Undefined job_class()") return job_jar job_src = job.source() if not job_src: logger.error("Both source() and jar() undefined") raise Exception("Both source() and jar() undefined") if not os.path.exists(job_src): logger.error("Can't find source: %s, full path %s", job_src, os.path.abspath(job_src)) raise Exception("job source does not exist") job_src = job.source() job_jar = self.get_tmp_job_jar(job_src) build_dir = self.get_build_dir(job_src) if not os.path.exists(build_dir): os.makedirs(build_dir) classpath = ":".join(filter(None, self.get_scalding_jars() + self.get_provided_jars() + self.get_libjars() + job.extra_jars())) scala_cp = ":".join(self.get_scala_jars(include_compiler=True)) # compile scala source arglist = ["java", "-cp", scala_cp, "scala.tools.nsc.Main", "-classpath", classpath, "-d", build_dir, job_src] logger.info("Compiling scala source: %s", subprocess.list2cmdline(arglist)) subprocess.check_call(arglist) # build job jar file arglist = ["jar", "cf", job_jar, "-C", build_dir, "."] logger.info("Building job jar: %s", subprocess.list2cmdline(arglist)) subprocess.check_call(arglist) return job_jar def run_job(self, job, 
tracking_url_callback=None): if tracking_url_callback is not None: warnings.warn("tracking_url_callback argument is deprecated, task.set_tracking_url is used instead.", DeprecationWarning) job_jar = self.build_job_jar(job) jars = [job_jar] + self.get_libjars() + job.extra_jars() scalding_core = self.get_scalding_core() libjars = ",".join(filter(None, jars)) arglist = luigi.contrib.hdfs.load_hadoop_cmd() + ["jar", scalding_core, "-libjars", libjars] arglist += ["-D%s" % c for c in job.jobconfs()] job_class = job.job_class() or self.get_job_class(job.source()) arglist += [job_class, "--hdfs"] # scalding does not parse argument with '=' properly arglist += ["--name", job.task_id.replace("=", ":")] (tmp_files, job_args) = luigi.contrib.hadoop_jar.fix_paths(job) arglist += job_args env = os.environ.copy() jars.append(scalding_core) hadoop_cp = ":".join(filter(None, jars)) env["HADOOP_CLASSPATH"] = hadoop_cp logger.info("Submitting Hadoop job: HADOOP_CLASSPATH=%s %s", hadoop_cp, subprocess.list2cmdline(arglist)) luigi.contrib.hadoop.run_and_track_hadoop_job(arglist, job.set_tracking_url, env=env) for a, b in tmp_files: a.move(b) class ScaldingJobTask(luigi.contrib.hadoop.BaseHadoopJobTask): """ A job task for Scalding that define a scala source and (optional) main method. requires() should return a dictionary where the keys are Scalding argument names and values are sub tasks or lists of subtasks. For example: .. code-block:: python {'input1': A, 'input2': C} => --input1 --input2 {'input1': [A, B], 'input2': [C]} => --input1 --input2 """ def relpath(self, current_file, rel_path): """ Compute path given current file and relative path. """ script_dir = os.path.dirname(os.path.abspath(current_file)) rel_path = os.path.abspath(os.path.join(script_dir, rel_path)) return rel_path def source(self): """ Path to the scala source for this Scalding Job Either one of source() or jar() must be specified. 
""" return None def jar(self): """ Path to the jar file for this Scalding Job Either one of source() or jar() must be specified. """ return None def extra_jars(self): """ Extra jars for building and running this Scalding Job. """ return [] def job_class(self): """ optional main job class for this Scalding Job. """ return None def job_runner(self): return ScaldingJobRunner() def atomic_output(self): """ If True, then rewrite output arguments to be temp locations and atomically move them into place after the job finishes. """ return True def requires(self): return {} def job_args(self): """ Extra arguments to pass to the Scalding job. """ return [] def args(self): """ Returns an array of args to pass to the job. """ arglist = [] for k, v in self.requires_hadoop().items(): arglist.append("--" + k) arglist.extend([t.output().path for t in flatten(v)]) arglist.extend(["--output", self.output()]) arglist.extend(self.job_args()) return arglist ================================================ FILE: luigi/contrib/sge.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """SGE batch system Tasks. Adapted by Jake Feala (@jfeala) from `LSF extension `_ by Alex Wiltschko (@alexbw) Maintained by Jake Feala (@jfeala) SunGrid Engine is a job scheduler used to allocate compute resources on a shared cluster. Jobs are submitted using the ``qsub`` command and monitored using ``qstat``. 
To get started, install luigi on all nodes. To run luigi workflows on an SGE cluster, subclass :class:`luigi.contrib.sge.SGEJobTask` as you would any :class:`luigi.Task`, but override the ``work()`` method, instead of ``run()``, to define the job code. Then, run your Luigi workflow from the master node, assigning > 1 ``workers`` in order to distribute the tasks in parallel across the cluster. The following is an example usage (and can also be found in ``sge_tests.py``) .. code-block:: python import logging import luigi import os from luigi.contrib.sge import SGEJobTask logger = logging.getLogger('luigi-interface') class TestJobTask(SGEJobTask): i = luigi.Parameter() def work(self): logger.info('Running test job...') with open(self.output().path, 'w') as f: f.write('this is a test') def output(self): return luigi.LocalTarget(os.path.join('/home', 'testfile_' + str(self.i))) if __name__ == '__main__': tasks = [TestJobTask(i=str(i), n_cpu=i+1) for i in range(3)] luigi.build(tasks, local_scheduler=True, workers=3) The ``n-cpu`` parameter allows you to define different compute resource requirements (or slots, in SGE terms) for each task. In this example, the third Task asks for 3 CPU slots. If your cluster only contains nodes with 2 CPUs, this task will hang indefinitely in the queue. See the docs for :class:`luigi.contrib.sge.SGEJobTask` for other SGE parameters. As for any task, you can also set these in your luigi configuration file as shown below. The default values below were matched to the values used by MIT StarCluster, an open-source SGE cluster manager for use with Amazon EC2:: [SGEJobTask] shared-tmp-dir = /home parallel-env = orte n-cpu = 2 """ # This extension is modeled after the hadoop.py approach. 
# # Implementation notes # The procedure: # - Pickle the class # - Construct a qsub argument that runs a generic runner function with the path to the pickled class # - Runner function loads the class from pickle # - Runner function hits the work button on it import logging import os import pickle import random import subprocess import sys import time import luigi from luigi.contrib import sge_runner from luigi.contrib.hadoop import create_packages_archive logger = logging.getLogger("luigi-interface") logger.propagate = False POLL_TIME = 5 # decided to hard-code rather than configure here def _parse_qstat_state(qstat_out, job_id): """Parse "state" column from `qstat` output for given job_id Returns state for the *first* job matching job_id. Returns 'u' if `qstat` output is empty or job_id is not found. """ if qstat_out.strip() == "": return "u" lines = qstat_out.split("\n") # skip past header while not lines.pop(0).startswith("---"): pass for line in lines: if line: job, prior, name, user, state = line.strip().split()[0:5] if int(job) == int(job_id): return state return "u" def _parse_qsub_job_id(qsub_out): """Parse job id from qsub output string. Assume format: "Your job ("") has been submitted" """ return int(qsub_out.split()[2]) def _build_qsub_command(cmd, job_name, outfile, errfile, pe, n_cpu): """Submit shell command to SGE queue via `qsub`""" qsub_template = """echo {cmd} | qsub -o ":{outfile}" -e ":{errfile}" -V -r y -pe {pe} {n_cpu} -N {job_name}""" return qsub_template.format(cmd=cmd, job_name=job_name, outfile=outfile, errfile=errfile, pe=pe, n_cpu=n_cpu) class SGEJobTask(luigi.Task): """Base class for executing a job on SunGrid Engine Override ``work()`` (rather than ``run()``) with your job code. Parameters: - n_cpu: Number of CPUs (or "slots") to allocate for the Task. This value is passed as ``qsub -pe {pe} {n_cpu}`` - parallel_env: SGE parallel environment name. The default is "orte", the parallel environment installed with MIT StarCluster. 
If you are using a different cluster environment, check with your sysadmin for the right pe to use. This value is passed as {pe} to the qsub command above. - shared_tmp_dir: Shared drive accessible from all nodes in the cluster. Task classes and dependencies are pickled to a temporary folder on this drive. The default is ``/home``, the NFS share location setup by StarCluster - job_name_format: String that can be passed in to customize the job name string passed to qsub; e.g. "Task123_{task_family}_{n_cpu}...". - job_name: Exact job name to pass to qsub. - run_locally: Run locally instead of on the cluster. - poll_time: the length of time to wait in order to poll qstat - dont_remove_tmp_dir: Instead of deleting the temporary directory, keep it. - no_tarball: Don't create a tarball of the luigi project directory. Can be useful to reduce I/O requirements when the luigi directory is accessible from cluster nodes already. """ n_cpu = luigi.IntParameter(default=2, significant=False) shared_tmp_dir = luigi.Parameter(default="/home", significant=False) parallel_env = luigi.Parameter(default="orte", significant=False) job_name_format = luigi.Parameter( significant=False, default=None, description="A string that can be formatted with class variables to name the job with qsub." 
) job_name = luigi.Parameter(significant=False, default=None, description="Explicit job name given via qsub.") run_locally = luigi.BoolParameter(significant=False, description="run locally instead of on the cluster") poll_time = luigi.IntParameter(significant=False, default=POLL_TIME, description="specify the wait time to poll qstat for the job status") dont_remove_tmp_dir = luigi.BoolParameter(significant=False, description="don't delete the temporary directory used (for debugging)") no_tarball = luigi.BoolParameter(significant=False, description="don't tarball (and extract) the luigi project files") def __init__(self, *args, **kwargs): super(SGEJobTask, self).__init__(*args, **kwargs) if self.job_name: # use explicitly provided job name pass elif self.job_name_format: # define the job name with the provided format self.job_name = self.job_name_format.format(task_family=self.task_family, **self.__dict__) else: # default to the task family self.job_name = self.task_family def _fetch_task_failures(self): if not os.path.exists(self.errfile): logger.info("No error file") return [] with open(self.errfile, "r") as f: errors = f.readlines() if errors == []: return errors if errors[0].strip() == "stdin: is not a tty": # SGE complains when we submit through a pipe errors.pop(0) return errors def _init_local(self): # Set up temp folder in shared directory (trim to max filename length) base_tmp_dir = self.shared_tmp_dir random_id = "%016x" % random.getrandbits(64) folder_name = self.task_id + "-" + random_id self.tmp_dir = os.path.join(base_tmp_dir, folder_name) max_filename_length = os.fstatvfs(0).f_namemax self.tmp_dir = self.tmp_dir[:max_filename_length] logger.info("Tmp dir: %s", self.tmp_dir) os.makedirs(self.tmp_dir) # Dump the code to be run into a pickle file logging.debug("Dumping pickled class") self._dump(self.tmp_dir) if not self.no_tarball: # Make sure that all the class's dependencies are tarred and available # This is not necessary if luigi is importable from 
the cluster node logging.debug("Tarballing dependencies") # Grab luigi and the module containing the code to be run packages = [luigi] + [__import__(self.__module__, None, None, "dummy")] create_packages_archive(packages, os.path.join(self.tmp_dir, "packages.tar")) def run(self): if self.run_locally: self.work() else: self._init_local() self._run_job() # The procedure: # - Pickle the class # - Tarball the dependencies # - Construct a qsub argument that runs a generic runner function with the path to the pickled class # - Runner function loads the class from pickle # - Runner class untars the dependencies # - Runner function hits the button on the class's work() method def work(self): """Override this method, rather than ``run()``, for your actual work.""" pass def _dump(self, out_dir=""): """Dump instance to file.""" with self.no_unpicklable_properties(): self.job_file = os.path.join(out_dir, "job-instance.pickle") if self.__module__ == "__main__": d = pickle.dumps(self) module_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0] d = d.replace("(c__main__", "(c" + module_name) with open(self.job_file, "w") as f: f.write(d) else: with open(self.job_file, "wb") as f: pickle.dump(self, f) def _run_job(self): # Build a qsub argument that will run sge_runner.py on the directory we've specified runner_path = sge_runner.__file__ if runner_path.endswith("pyc"): runner_path = runner_path[:-3] + "py" job_str = 'python {0} "{1}" "{2}"'.format(runner_path, self.tmp_dir, os.getcwd()) # enclose tmp_dir in quotes to protect from special escape chars if self.no_tarball: job_str += ' "--no-tarball"' # Build qsub submit command self.outfile = os.path.join(self.tmp_dir, "job.out") self.errfile = os.path.join(self.tmp_dir, "job.err") submit_cmd = _build_qsub_command(job_str, self.task_family, self.outfile, self.errfile, self.parallel_env, self.n_cpu) logger.debug("qsub command: \n" + submit_cmd) # Submit the job and grab job ID output = subprocess.check_output(submit_cmd, 
shell=True) self.job_id = _parse_qsub_job_id(output) logger.debug("Submitted job to qsub with response:\n" + output) self._track_job() # Now delete the temporaries, if they're there. if self.tmp_dir and os.path.exists(self.tmp_dir) and not self.dont_remove_tmp_dir: logger.info("Removing temporary directory %s" % self.tmp_dir) subprocess.call(["rm", "-rf", self.tmp_dir]) def _track_job(self): while True: # Sleep for a little bit time.sleep(self.poll_time) # See what the job's up to # ASSUMPTION qstat_out = subprocess.check_output(["qstat"]) sge_status = _parse_qstat_state(qstat_out, self.job_id) if sge_status == "r": logger.info("Job is running...") elif sge_status == "qw": logger.info("Job is pending...") elif "E" in sge_status: logger.error("Job has FAILED:\n" + "\n".join(self._fetch_task_failures())) break elif sge_status == "t" or sge_status == "u": # Then the job could either be failed or done. errors = self._fetch_task_failures() if not errors: logger.info("Job is done") else: logger.error("Job has FAILED:\n" + "\n".join(errors)) break else: logger.info("Job status is UNKNOWN!") logger.info("Status is : %s" % sge_status) raise Exception("job status isn't one of ['r', 'qw', 'E*', 't', 'u']: %s" % sge_status) class LocalSGEJobTask(SGEJobTask): """A local version of SGEJobTask, for easier debugging. This version skips the ``qsub`` steps and simply runs ``work()`` on the local node, so you don't need to be on an SGE cluster to use your Task in a test workflow. """ def run(self): self.work() ================================================ FILE: luigi/contrib/sge_runner.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ The SunGrid Engine runner The main() function of this module will be executed on the compute node by the submitted job. It accepts as a single argument the shared temp folder containing the package archive and pickled task to run, and carries out these steps: - extract tarfile of package dependencies and place on the path - unpickle SGETask instance created on the master node - run SGETask.work() On completion, SGETask on the master node will detect that the job has left the queue, delete the temporary folder, and return from SGETask.run() """ import logging import os import pickle import sys from luigi.safe_extractor import SafeExtractor def _do_work_on_compute_node(work_dir, tarball=True): if tarball: # Extract the necessary dependencies # This can create a lot of I/O overhead when running many SGEJobTasks, # so is optional if the luigi project is accessible from the cluster node _extract_packages_archive(work_dir) # Open up the pickle file with the work to be done os.chdir(work_dir) with open("job-instance.pickle", "r") as f: job = pickle.load(f) # Do the work contained job.work() def _extract_packages_archive(work_dir): package_file = os.path.join(work_dir, "packages.tar") if not os.path.exists(package_file): return curdir = os.path.abspath(os.curdir) os.chdir(work_dir) extractor = SafeExtractor(work_dir) extractor.safe_extract(package_file) if "" not in sys.path: sys.path.insert(0, "") os.chdir(curdir) def main(args=sys.argv): """Run the work() method from the class instance in the file "job-instance.pickle".""" try: tarball = "--no-tarball" not in args # 
Set up logging. logging.basicConfig(level=logging.WARN) work_dir = args[1] assert os.path.exists(work_dir), "First argument to sge_runner.py must be a directory that exists" project_dir = args[2] sys.path.append(project_dir) _do_work_on_compute_node(work_dir, tarball) except Exception as e: # Dump encoded data that we will try to fetch using mechanize print(e) raise if __name__ == "__main__": main() ================================================ FILE: luigi/contrib/simulate.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ A module containing classes used to simulate certain behaviors """ import hashlib import logging import os import tempfile from multiprocessing import Value import luigi logger = logging.getLogger("luigi-interface") class RunAnywayTarget(luigi.Target): """ A target used to make a task run every time it is called. Usage: Pass `self` as the first argument in your task's `output`: .. code-block: python def output(self): return RunAnywayTarget(self) And then mark it as `done` in your task's `run`: .. code-block: python def run(self): # Your task execution # ... self.output().done() # will then be considered as "existing" """ # Specify the location of the temporary folder storing the state files. 
Subclass to change this value temp_dir = os.path.join(tempfile.gettempdir(), "luigi-simulate") temp_time = 24 * 3600 # seconds # Unique value (PID of the first encountered target) to separate temporary files between executions and # avoid deletion collision unique = Value("i", 0) def __init__(self, task_obj): self.task_id = task_obj.task_id if self.unique.value == 0: with self.unique.get_lock(): if self.unique.value == 0: self.unique.value = os.getpid() # The PID will be unique for every execution of the pipeline # Deleting old files > temp_time if os.path.isdir(self.temp_dir): import shutil import time limit = time.time() - self.temp_time for fn in os.listdir(self.temp_dir): path = os.path.join(self.temp_dir, fn) if os.path.isdir(path) and os.stat(path).st_mtime < limit: shutil.rmtree(path) logger.debug("Deleted temporary directory %s", path) def __str__(self): return self.task_id def get_path(self): """ Returns a temporary file path based on a MD5 hash generated with the task's name and its arguments """ md5_hash = hashlib.new("md5", self.task_id.encode(), usedforsecurity=False).hexdigest() logger.debug("Hash %s corresponds to task %s", md5_hash, self.task_id) return os.path.join(self.temp_dir, str(self.unique.value), md5_hash) def exists(self): """ Checks if the file exists """ return os.path.isfile(self.get_path()) def done(self): """ Creates temporary file to mark the task as `done` """ logger.info("Marking %s as done", self) fn = self.get_path() try: os.makedirs(os.path.dirname(fn)) except OSError: pass open(fn, "w").close() ================================================ FILE: luigi/contrib/spark.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import collections import importlib import inspect import logging import os import pickle import re import shutil import sys import tarfile import tempfile from luigi import configuration from luigi.contrib.external_program import ExternalProgramTask logger = logging.getLogger("luigi-interface") class SparkSubmitTask(ExternalProgramTask): """ Template task for running a Spark job Supports running jobs on Spark local, standalone, Mesos or Yarn See http://spark.apache.org/docs/latest/submitting-applications.html for more information """ # Application (.jar or .py file) name = None entry_class = None app = None # Only log stderr if spark fails (since stderr is normally quite verbose) always_log_stderr = False # Spark applications write its logs into stderr stream_for_searching_tracking_url = "stderr" @property def tracking_url_pattern(self): if self.deploy_mode == "cluster": # in cluster mode client only receives application status once a period of time return r"tracking URL: (https?://.*)\s" else: return r"Bound (?:.*) to (?:.*), and started at (https?://.*)\s" def app_options(self): """ Subclass this method to map your task parameters to the app's arguments """ return [] @property def pyspark_python(self): return None @property def pyspark_driver_python(self): return None @property def hadoop_user_name(self): return None @property def spark_version(self): return "spark" @property def spark_submit(self): return configuration.get_config().get(self.spark_version, "spark-submit", "spark-submit") @property def master(self): return 
configuration.get_config().get(self.spark_version, "master", None) @property def deploy_mode(self): return configuration.get_config().get(self.spark_version, "deploy-mode", None) @property def jars(self): return self._list_config(configuration.get_config().get(self.spark_version, "jars", None)) @property def packages(self): return self._list_config(configuration.get_config().get(self.spark_version, "packages", None)) @property def py_files(self): return self._list_config(configuration.get_config().get(self.spark_version, "py-files", None)) @property def files(self): return self._list_config(configuration.get_config().get(self.spark_version, "files", None)) @property def _conf(self): conf = collections.OrderedDict(self.conf or {}) if self.pyspark_python: conf["spark.pyspark.python"] = self.pyspark_python if self.pyspark_driver_python: conf["spark.pyspark.driver.python"] = self.pyspark_driver_python return conf @property def conf(self): return self._dict_config(configuration.get_config().get(self.spark_version, "conf", None)) @property def properties_file(self): return configuration.get_config().get(self.spark_version, "properties-file", None) @property def driver_memory(self): return configuration.get_config().get(self.spark_version, "driver-memory", None) @property def driver_java_options(self): return configuration.get_config().get(self.spark_version, "driver-java-options", None) @property def driver_library_path(self): return configuration.get_config().get(self.spark_version, "driver-library-path", None) @property def driver_class_path(self): return configuration.get_config().get(self.spark_version, "driver-class-path", None) @property def executor_memory(self): return configuration.get_config().get(self.spark_version, "executor-memory", None) @property def driver_cores(self): return configuration.get_config().get(self.spark_version, "driver-cores", None) @property def supervise(self): return bool(configuration.get_config().get(self.spark_version, "supervise", 
False)) @property def total_executor_cores(self): return configuration.get_config().get(self.spark_version, "total-executor-cores", None) @property def executor_cores(self): return configuration.get_config().get(self.spark_version, "executor-cores", None) @property def queue(self): return configuration.get_config().get(self.spark_version, "queue", None) @property def num_executors(self): return configuration.get_config().get(self.spark_version, "num-executors", None) @property def archives(self): return self._list_config(configuration.get_config().get(self.spark_version, "archives", None)) @property def hadoop_conf_dir(self): return configuration.get_config().get(self.spark_version, "hadoop-conf-dir", None) def get_environment(self): env = os.environ.copy() for prop in ("HADOOP_CONF_DIR", "HADOOP_USER_NAME"): var = getattr(self, prop.lower(), None) if var: env[prop] = var return env def program_environment(self): return self.get_environment() def program_args(self): return self.spark_command() + self.app_command() def spark_command(self): command = [self.spark_submit] command += self._text_arg("--master", self.master) command += self._text_arg("--deploy-mode", self.deploy_mode) command += self._text_arg("--name", self.name) command += self._text_arg("--class", self.entry_class) command += self._list_arg("--jars", self.jars) command += self._list_arg("--packages", self.packages) command += self._list_arg("--py-files", self.py_files) command += self._list_arg("--files", self.files) command += self._list_arg("--archives", self.archives) command += self._dict_arg("--conf", self._conf) command += self._text_arg("--properties-file", self.properties_file) command += self._text_arg("--driver-memory", self.driver_memory) command += self._text_arg("--driver-java-options", self.driver_java_options) command += self._text_arg("--driver-library-path", self.driver_library_path) command += self._text_arg("--driver-class-path", self.driver_class_path) command += 
self._text_arg("--executor-memory", self.executor_memory) command += self._text_arg("--driver-cores", self.driver_cores) command += self._flag_arg("--supervise", self.supervise) command += self._text_arg("--total-executor-cores", self.total_executor_cores) command += self._text_arg("--executor-cores", self.executor_cores) command += self._text_arg("--queue", self.queue) command += self._text_arg("--num-executors", self.num_executors) return command def app_command(self): if not self.app: raise NotImplementedError("subclass should define an app (.jar or .py file)") return [self.app] + self.app_options() def _list_config(self, config): if config and isinstance(config, str): return list(map(lambda x: x.strip(), config.split(","))) def _dict_config(self, config): if config and isinstance(config, str): return dict(map(lambda i: i.split("=", 1), config.split("|"))) def _text_arg(self, name, value): if value: return [name, value] return [] def _list_arg(self, name, value): if value and isinstance(value, (list, tuple)): return [name, ",".join(value)] return [] def _dict_arg(self, name, value): command = [] if value and isinstance(value, dict): for prop, value in value.items(): command += [name, "{0}={1}".format(prop, value)] return command def _flag_arg(self, name, value): if value: return [name] return [] class PySparkTask(SparkSubmitTask): """ Template task for running an inline PySpark job Simply implement the ``main`` method in your subclass You can optionally define package names to be distributed to the cluster with ``py_packages`` (uses luigi's global py-packages configuration by default) """ # Path to the pyspark program passed to spark-submit app = os.path.join(os.path.dirname(__file__), "pyspark_runner.py") @property def name(self): return self.__class__.__name__ @property def py_packages(self): packages = configuration.get_config().get("spark", "py-packages", None) if packages: return map(lambda s: s.strip(), packages.split(",")) @property def files(self): if 
self.deploy_mode == "cluster": return [self.run_pickle] @property def pickle_protocol(self): return configuration.get_config().getint("spark", "pickle-protocol", pickle.DEFAULT_PROTOCOL) def setup(self, conf): """ Called by the pyspark_runner with a SparkConf instance that will be used to instantiate the SparkContext :param conf: SparkConf """ def setup_remote(self, sc): self._setup_packages(sc) def main(self, sc, *args): """ Called by the pyspark_runner with a SparkContext and any arguments returned by ``app_options()`` :param sc: SparkContext :param args: arguments list """ raise NotImplementedError("subclass should define a main method") def app_command(self): if self.deploy_mode == "cluster": pickle_loc = os.path.basename(self.run_pickle) else: pickle_loc = self.run_pickle return [self.app, pickle_loc] + self.app_options() def run(self): path_name_fragment = re.sub(r"[^\w]", "_", self.name) self.run_path = tempfile.mkdtemp(prefix=path_name_fragment) self.run_pickle = os.path.join(self.run_path, ".".join([path_name_fragment, "pickle"])) with open(self.run_pickle, "wb") as fd: # Copy module file to run path. 
module_path = os.path.abspath(inspect.getfile(self.__class__)) shutil.copy(module_path, os.path.join(self.run_path, ".")) self._dump(fd) try: super(PySparkTask, self).run() finally: shutil.rmtree(self.run_path) def _dump(self, fd): with self.no_unpicklable_properties(): if self.__module__ == "__main__": d = pickle.dumps(self, protocol=self.pickle_protocol) module_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0] d = d.replace(b"c__main__", b"c" + module_name.encode("ascii")) fd.write(d) else: pickle.dump(self, fd, protocol=self.pickle_protocol) def _setup_packages(self, sc): """ This method compresses and uploads packages to the cluster """ packages = self.py_packages if not packages: return for package in packages: mod = importlib.import_module(package) try: mod_path = mod.__path__[0] except AttributeError: mod_path = mod.__file__ os.makedirs(self.run_path, exist_ok=True) tar_path = os.path.join(self.run_path, package + ".tar.gz") tar = tarfile.open(tar_path, "w:gz") tar.add(mod_path, os.path.basename(mod_path)) tar.close() sc.addPyFile(tar_path) ================================================ FILE: luigi/contrib/sparkey.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import luigi class SparkeyExportTask(luigi.Task): """ A luigi task that writes to a local sparkey log file. Subclasses should implement the requires and output methods. 
The output must be a luigi.LocalTarget. The resulting sparkey log file will contain one entry for every line in the input, mapping from the first value to a tab-separated list of the rest of the line. To generate a simple key-value index, yield "key", "value" pairs from the input(s) to this task. """ # the separator used to split input lines separator = "\t" def __init__(self, *args, **kwargs): super(SparkeyExportTask, self).__init__(*args, **kwargs) def run(self): self._write_sparkey_file() def _write_sparkey_file(self): import sparkey infile = self.input() outfile = self.output() if not isinstance(outfile, luigi.LocalTarget): raise TypeError("output must be a LocalTarget") # write job output to temporary sparkey file temp_output = luigi.LocalTarget(is_tmp=True) w = sparkey.LogWriter(temp_output.path) for line in infile.open("r"): k, v = line.strip().split(self.separator, 1) w[k] = v w.close() # move finished sparkey file to final destination temp_output.move(outfile.path) ================================================ FILE: luigi/contrib/sqla.py ================================================ # -*- coding: utf-8 -*- # # Copyright (c) 2015 Gouthaman Balaraman # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of # the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations under # the License. # """ Support for SQLAlchemy. Provides SQLAlchemyTarget for storing in databases supported by SQLAlchemy. The user would be responsible for installing the required database driver to connect using SQLAlchemy. 
Minimal example of a job to copy data to database using SQLAlchemy is as shown below: .. code-block:: python from sqlalchemy import String import luigi from luigi.contrib import sqla class SQLATask(sqla.CopyToTable): # columns defines the table schema, with each element corresponding # to a column in the format (args, kwargs) which will be sent to # the sqlalchemy.Column(*args, **kwargs) columns = [ (["item", String(64)], {"primary_key": True}), (["property", String(64)], {}) ] connection_string = "sqlite://" # in memory SQLite database table = "item_property" # name of the table to store data def rows(self): for row in [("item1", "property1"), ("item2", "property2")]: yield row if __name__ == '__main__': task = SQLATask() luigi.build([task], local_scheduler=True) If the target table where the data needs to be copied already exists, then the column schema definition can be skipped and instead the reflect flag can be set as True. Here is a modified version of the above example: .. code-block:: python from sqlalchemy import String import luigi from luigi.contrib import sqla class SQLATask(sqla.CopyToTable): # If database table is already created, then the schema can be loaded # by setting the reflect flag to True reflect = True connection_string = "sqlite://" # in memory SQLite database table = "item_property" # name of the table to store data def rows(self): for row in [("item1", "property1"), ("item2", "property2")]: yield row if __name__ == '__main__': task = SQLATask() luigi.build([task], local_scheduler=True) In the above examples, the data that needs to be copied was directly provided by overriding the rows method. Alternately, if the data comes from another task, the modified example would look as shown below: .. 
code-block:: python from sqlalchemy import String import luigi from luigi.contrib import sqla from luigi.mock import MockTarget class BaseTask(luigi.Task): def output(self): return MockTarget("BaseTask") def run(self): out = self.output().open("w") TASK_LIST = ["item%d\\tproperty%d\\n" % (i, i) for i in range(10)] for task in TASK_LIST: out.write(task) out.close() class SQLATask(sqla.CopyToTable): # columns defines the table schema, with each element corresponding # to a column in the format (args, kwargs) which will be sent to # the sqlalchemy.Column(*args, **kwargs) columns = [ (["item", String(64)], {"primary_key": True}), (["property", String(64)], {}) ] connection_string = "sqlite://" # in memory SQLite database table = "item_property" # name of the table to store data def requires(self): return BaseTask() if __name__ == '__main__': task1, task2 = SQLATask(), BaseTask() luigi.build([task1, task2], local_scheduler=True) In the above example, the output from `BaseTask` is copied into the database. Here we did not have to implement the `rows` method because by default `rows` implementation assumes every line is a row with column values separated by a tab. One can define `column_separator` option for the task if the values are say comma separated instead of tab separated. You can pass in database specific connection arguments by setting the connect_args dictionary. The options will be passed directly to the DBAPI's connect method as keyword arguments. The other option to `sqla.CopyToTable` that can be of help with performance aspect is the `chunk_size`. The default is 5000. This is the number of rows that will be inserted in a transaction at a time. Depending on the size of the inserts, this value can be tuned for performance. See here for a `tutorial on building task pipelines using luigi `_ and using `SQLAlchemy in workflow pipelines `_. 
Author: Gouthaman Balaraman Date: 01/02/2015 """ import abc import collections import datetime import itertools import logging import os import sqlalchemy import luigi class SQLAlchemyTarget(luigi.Target): """ Database target using SQLAlchemy. This will rarely have to be directly instantiated by the user. Typical usage would be to override `luigi.contrib.sqla.CopyToTable` class to create a task to write to the database. """ marker_table = None _engine_dict = {} # dict of sqlalchemy engine instances Connection = collections.namedtuple("Connection", "engine pid") def __init__(self, connection_string, target_table, update_id, echo=False, connect_args=None): """ Constructor for the SQLAlchemyTarget. :param connection_string: SQLAlchemy connection string :type connection_string: str :param target_table: The table name for the data :type target_table: str :param update_id: An identifier for this data set :type update_id: str :param echo: Flag to setup SQLAlchemy logging :type echo: bool :param connect_args: A dictionary of connection arguments :type connect_args: dict :return: """ if connect_args is None: connect_args = {} self.target_table = target_table self.update_id = update_id self.connection_string = connection_string self.echo = echo self.connect_args = connect_args self.marker_table_bound = None def __str__(self): return self.target_table @property def engine(self): """ Return an engine instance, creating it if it doesn't exist. Recreate the engine connection if it wasn't originally created by the current process. 
""" pid = os.getpid() conn = SQLAlchemyTarget._engine_dict.get(self.connection_string) if not conn or conn.pid != pid: # create and reset connection engine = sqlalchemy.create_engine(self.connection_string, connect_args=self.connect_args, echo=self.echo) SQLAlchemyTarget._engine_dict[self.connection_string] = self.Connection(engine, pid) return SQLAlchemyTarget._engine_dict[self.connection_string].engine def touch(self): """ Mark this update as complete. """ if self.marker_table_bound is None: self.create_marker_table() table = self.marker_table_bound id_exists = self.exists() with self.engine.begin() as conn: if not id_exists: ins = table.insert().values(update_id=self.update_id, target_table=self.target_table, inserted=datetime.datetime.now()) else: ins = ( table.update() .where(sqlalchemy.and_(table.c.update_id == self.update_id, table.c.target_table == self.target_table)) .values(update_id=self.update_id, target_table=self.target_table, inserted=datetime.datetime.now()) ) conn.execute(ins) assert self.exists() def exists(self): row = None if self.marker_table_bound is None: self.create_marker_table() with self.engine.begin() as conn: table = self.marker_table_bound s = sqlalchemy.select([table]).where(sqlalchemy.and_(table.c.update_id == self.update_id, table.c.target_table == self.target_table)).limit(1) row = conn.execute(s).fetchone() return row is not None def create_marker_table(self): """ Create marker table if it doesn't exist. Using a separate connection since the transaction might have to be reset. 
""" if self.marker_table is None: self.marker_table = luigi.configuration.get_config().get("sqlalchemy", "marker-table", "table_updates") engine = self.engine with engine.begin() as con: metadata = sqlalchemy.MetaData() if not con.dialect.has_table(con, self.marker_table): self.marker_table_bound = sqlalchemy.Table( self.marker_table, metadata, sqlalchemy.Column("update_id", sqlalchemy.String(128), primary_key=True), sqlalchemy.Column("target_table", sqlalchemy.String(128)), sqlalchemy.Column("inserted", sqlalchemy.DateTime, default=datetime.datetime.now()), ) metadata.create_all(engine) else: metadata.reflect(only=[self.marker_table], bind=engine) self.marker_table_bound = metadata.tables[self.marker_table] def open(self, mode): raise NotImplementedError("Cannot open() SQLAlchemyTarget") class CopyToTable(luigi.Task): """ An abstract task for inserting a data set into SQLAlchemy RDBMS Usage: * subclass and override the required `connection_string`, `table` and `columns` attributes. * optionally override the `schema` attribute to use a different schema for the target table. """ _logger = logging.getLogger("luigi-interface") echo = False connect_args = {} @property @abc.abstractmethod def connection_string(self): return None @property @abc.abstractmethod def table(self): return None # specify the columns that define the schema. The format for the columns is a list # of tuples. For example : # columns = [ # (["id", sqlalchemy.Integer], dict(primary_key=True)), # (["name", sqlalchemy.String(64)], {}), # (["value", sqlalchemy.String(64)], {}) # ] # The tuple (args_list, kwargs_dict) here is the args and kwargs # that need to be passed to sqlalchemy.Column(*args, **kwargs). # If the tables have already been setup by another process, then you can # completely ignore the columns. Instead set the reflect value to True below columns = [] # Specify the database schema of the target table, if supported by the # RDBMS. 
Note that this doesn't change the schema of the marker table. # The schema MUST already exist in the database, or this will task fail. schema = "" # options column_separator = "\t" # how columns are separated in the file copied into postgres chunk_size = 5000 # default chunk size for insert reflect = False # Set this to true only if the table has already been created by alternate means def create_table(self, engine): """ Override to provide code for creating the target table. By default it will be created using types specified in columns. If the table exists, then it binds to the existing table. If overridden, use the provided connection object for setting up the table in order to create the table and insert data using the same transaction. :param engine: The sqlalchemy engine instance :type engine: object """ def construct_sqla_columns(columns): retval = [sqlalchemy.Column(*c[0], **c[1]) for c in columns] return retval needs_setup = (len(self.columns) == 0) or (False in [len(c) == 2 for c in self.columns]) if not self.reflect else False if needs_setup: # only names of columns specified, no types raise NotImplementedError("create_table() not implemented for %r and columns types not specified" % self.table) else: # if columns is specified as (name, type) tuples with engine.begin() as con: if self.schema: metadata = sqlalchemy.MetaData(schema=self.schema) else: metadata = sqlalchemy.MetaData() try: if not con.dialect.has_table(con, self.table, self.schema or None): sqla_columns = construct_sqla_columns(self.columns) self.table_bound = sqlalchemy.Table(self.table, metadata, *sqla_columns) metadata.create_all(engine) else: full_table = ".".join([self.schema, self.table]) if self.schema else self.table metadata.reflect(only=[self.table], bind=engine) self.table_bound = metadata.tables[full_table] except Exception as e: self._logger.exception(self.table + str(e)) def update_id(self): """ This update id will be a unique identifier for this insert on this table. 
""" return self.task_id def output(self): return SQLAlchemyTarget( connection_string=self.connection_string, target_table=self.table, update_id=self.update_id(), connect_args=self.connect_args, echo=self.echo ) def rows(self): """ Return/yield tuples or lists corresponding to each row to be inserted. This method can be overridden for custom file types or formats. """ with self.input().open("r") as fobj: for line in fobj: yield line.strip("\n").split(self.column_separator) def run(self): self._logger.info("Running task copy to table for update id %s for table %s" % (self.update_id(), self.table)) output = self.output() engine = output.engine self.create_table(engine) with engine.begin() as conn: rows = iter(self.rows()) ins_rows = [dict(zip(("_" + c.key for c in self.table_bound.c), row)) for row in itertools.islice(rows, self.chunk_size)] while ins_rows: self.copy(conn, ins_rows, self.table_bound) ins_rows = [dict(zip(("_" + c.key for c in self.table_bound.c), row)) for row in itertools.islice(rows, self.chunk_size)] self._logger.info("Finished inserting %d rows into SQLAlchemy target" % len(ins_rows)) output.touch() self._logger.info("Finished inserting rows into SQLAlchemy target") def copy(self, conn, ins_rows, table_bound): """ This method does the actual insertion of the rows of data given by ins_rows into the database. A task that needs row updates instead of insertions should overload this method. :param conn: The sqlalchemy connection object :param ins_rows: The dictionary of rows with the keys in the format _. For example if you have a table with a column name "property", then the key in the dictionary would be "_property". This format is consistent with the bindparam usage in sqlalchemy. 
:param table_bound: The object referring to the table :return: """ bound_cols = dict((c, sqlalchemy.bindparam("_" + c.key)) for c in table_bound.columns) ins = table_bound.insert().values(bound_cols) conn.execute(ins, ins_rows) ================================================ FILE: luigi/contrib/ssh.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Light-weight remote execution library and utilities. There are some examples in the unittest but I added another that is more luigi-specific in the examples directory (examples/ssh_remote_execution.py) :class:`RemoteContext` is meant to provide functionality similar to that of the standard library subprocess module, but where the commands executed are run on a remote machine instead, without the user having to think about prefixing everything with "ssh" and credentials etc. Using this mini library (which is just a convenience wrapper for subprocess), :class:`RemoteTarget` is created to let you stream data from a remotely stored file using the luigi :class:`~luigi.target.FileSystemTarget` semantics. As a bonus, :class:`RemoteContext` also provides a really cool feature that let's you set up ssh tunnels super easily using a python context manager (there is an example in the integration part of unittests). 
This can be super convenient when you want secure communication using a non-secure protocol or circumvent firewalls (as long as they are open for ssh traffic). """ import contextlib import logging import os import posixpath import random import subprocess import luigi import luigi.format import luigi.target logger = logging.getLogger("luigi-interface") class RemoteCalledProcessError(subprocess.CalledProcessError): def __init__(self, returncode, command, host, output=None): super(RemoteCalledProcessError, self).__init__(returncode, command, output) self.host = host def __str__(self): return "Command '%s' on host %s returned non-zero exit status %d" % (self.cmd, self.host, self.returncode) class RemoteContext: def __init__(self, host, **kwargs): self.host = host self.username = kwargs.get("username", None) self.key_file = kwargs.get("key_file", None) self.connect_timeout = kwargs.get("connect_timeout", None) self.port = kwargs.get("port", None) self.no_host_key_check = kwargs.get("no_host_key_check", False) self.sshpass = kwargs.get("sshpass", False) self.tty = kwargs.get("tty", False) def __repr__(self): return "%s(%r, %r, %r, %r, %r)" % (type(self).__name__, self.host, self.username, self.key_file, self.connect_timeout, self.port) def __eq__(self, other): return repr(self) == repr(other) def __hash__(self): return hash(repr(self)) def _host_ref(self): if self.username: return "{0}@{1}".format(self.username, self.host) else: return self.host def _prepare_cmd(self, cmd): connection_cmd = ["ssh", self._host_ref(), "-o", "ControlMaster=no"] if self.sshpass: connection_cmd = ["sshpass", "-e"] + connection_cmd else: connection_cmd += ["-o", "BatchMode=yes"] # no password prompts etc if self.port: connection_cmd.extend(["-p", self.port]) if self.connect_timeout is not None: connection_cmd += ["-o", "ConnectTimeout=%d" % self.connect_timeout] if self.no_host_key_check: connection_cmd += ["-o", "UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no"] if 
self.key_file: connection_cmd.extend(["-i", self.key_file]) if self.tty: connection_cmd.append("-t") return connection_cmd + cmd def Popen(self, cmd, **kwargs): """ Remote Popen. """ prefixed_cmd = self._prepare_cmd(cmd) return subprocess.Popen(prefixed_cmd, **kwargs) def check_output(self, cmd): """ Execute a shell command remotely and return the output. Simplified version of Popen when you only want the output as a string and detect any errors. """ p = self.Popen(cmd, stdout=subprocess.PIPE) output, _ = p.communicate() if p.returncode != 0: raise RemoteCalledProcessError(p.returncode, cmd, self.host, output=output) return output @contextlib.contextmanager def tunnel(self, local_port, remote_port=None, remote_host="localhost"): """ Open a tunnel between localhost:local_port and remote_host:remote_port via the host specified by this context. Remember to close() the returned "tunnel" object in order to clean up after yourself when you are done with the tunnel. """ tunnel_host = "{0}:{1}:{2}".format(local_port, remote_host, remote_port) proc = self.Popen( # cat so we can shut down gracefully by closing stdin ["-L", tunnel_host, "echo -n ready && cat"], stdin=subprocess.PIPE, stdout=subprocess.PIPE, ) # make sure to get the data so we know the connection is established ready = proc.stdout.read(5) assert ready == b"ready", "Didn't get ready from remote echo" yield # user code executed here proc.communicate() assert proc.returncode == 0, "Tunnel process did an unclean exit (returncode %s)" % (proc.returncode,) class RemoteFileSystem(luigi.target.FileSystem): def __init__(self, host, **kwargs): self.remote_context = RemoteContext(host, **kwargs) def exists(self, path): """ Return `True` if file or directory at `path` exist, False otherwise. 
""" try: self.remote_context.check_output(["test", "-e", path]) except subprocess.CalledProcessError as e: if e.returncode == 1: return False else: raise return True def listdir(self, path): while path.endswith("/"): path = path[:-1] path = path or "." listing = self.remote_context.check_output(["find", "-L", path, "-type", "f"]).splitlines() return [v.decode("utf-8") for v in listing] def isdir(self, path): """ Return `True` if directory at `path` exist, False otherwise. """ try: self.remote_context.check_output(["test", "-d", path]) except subprocess.CalledProcessError as e: if e.returncode == 1: return False else: raise return True def remove(self, path, recursive=True): """ Remove file or directory at location `path`. """ if recursive: cmd = ["rm", "-r", path] else: cmd = ["rm", path] self.remote_context.check_output(cmd) def mkdir(self, path, parents=True, raise_if_exists=False): if self.exists(path): if raise_if_exists: raise luigi.target.FileAlreadyExists() elif not self.isdir(path): raise luigi.target.NotADirectory() else: return if parents: cmd = ["mkdir", "-p", path] else: cmd = ["mkdir", path, "2>&1"] try: self.remote_context.check_output(cmd) except subprocess.CalledProcessError as e: if b"no such file" in e.output.lower(): raise luigi.target.MissingParentDirectory() raise def _scp(self, src, dest): cmd = ["scp", "-q", "-C", "-o", "ControlMaster=no"] if self.remote_context.sshpass: cmd = ["sshpass", "-e"] + cmd else: cmd.append("-B") if self.remote_context.no_host_key_check: cmd.extend(["-o", "UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no"]) if self.remote_context.key_file: cmd.extend(["-i", self.remote_context.key_file]) if self.remote_context.port: cmd.extend(["-P", self.remote_context.port]) if os.path.isdir(src): cmd.extend(["-r"]) cmd.extend([src, dest]) p = subprocess.Popen(cmd) output, _ = p.communicate() if p.returncode != 0: raise subprocess.CalledProcessError(p.returncode, cmd, output=output) def put(self, local_path, path): # 
create parent folder if not exists normpath = posixpath.normpath(path) folder = os.path.dirname(normpath) if folder and not self.exists(folder): self.remote_context.check_output(["mkdir", "-p", folder]) tmp_path = path + "-luigi-tmp-%09d" % random.randrange(0, 10_000_000_000) self._scp(local_path, "%s:%s" % (self.remote_context._host_ref(), tmp_path)) self.remote_context.check_output(["mv", tmp_path, path]) def get(self, path, local_path): # Create folder if it does not exist normpath = os.path.normpath(local_path) folder = os.path.dirname(normpath) if folder: try: os.makedirs(folder) except OSError: pass tmp_local_path = local_path + "-luigi-tmp-%09d" % random.randrange(0, 10_000_000_000) self._scp("%s:%s" % (self.remote_context._host_ref(), path), tmp_local_path) os.replace(tmp_local_path, local_path) class AtomicRemoteFileWriter(luigi.format.OutputPipeProcessWrapper): def __init__(self, fs, path): self._fs = fs self.path = path # create parent folder if not exists normpath = os.path.normpath(self.path) folder = os.path.dirname(normpath) if folder: self.fs.mkdir(folder) self.__tmp_path = self.path + "-luigi-tmp-%09d" % random.randrange(0, 10_000_000_000) super(AtomicRemoteFileWriter, self).__init__(self.fs.remote_context._prepare_cmd(["cat", ">", self.__tmp_path])) def __del__(self): super(AtomicRemoteFileWriter, self).__del__() try: if self.fs.exists(self.__tmp_path): self.fs.remote_context.check_output(["rm", self.__tmp_path]) except Exception: # Don't propagate the exception; bad things can happen. logger.exception("Failed to delete in-flight file") def close(self): super(AtomicRemoteFileWriter, self).close() self.fs.remote_context.check_output(["mv", self.__tmp_path, self.path]) @property def tmp_path(self): return self.__tmp_path @property def fs(self): return self._fs class RemoteTarget(luigi.target.FileSystemTarget): """ Target used for reading from remote files. The target is implemented using ssh commands streaming data over the network. 
""" def __init__(self, path, host, format=None, **kwargs): super(RemoteTarget, self).__init__(path) if format is None: format = luigi.format.get_default_format() self.format = format self._fs = RemoteFileSystem(host, **kwargs) @property def fs(self): return self._fs def open(self, mode="r"): if mode == "w": file_writer = AtomicRemoteFileWriter(self.fs, self.path) if self.format: return self.format.pipe_writer(file_writer) else: return file_writer elif mode == "r": file_reader = luigi.format.InputPipeProcessWrapper(self.fs.remote_context._prepare_cmd(["cat", self.path])) if self.format: return self.format.pipe_reader(file_reader) else: return file_reader else: raise Exception("mode must be 'r' or 'w' (got: %s)" % mode) def put(self, local_path): self.fs.put(local_path, self.path) def get(self, local_path): self.fs.get(self.path, local_path) ================================================ FILE: luigi/contrib/target.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import logging from types import MethodType import luigi.target logger = logging.getLogger("luigi-interface") class CascadingClient: """ A FilesystemClient that will cascade failing function calls through a list of clients. Which clients are used are specified at time of construction. """ # This constant member is supposed to include all methods, feel free to add # methods here. 
If you want full control of which methods that should be # created, pass the kwarg to the constructor. ALL_METHOD_NAMES = [ "exists", "rename", "remove", "chmod", "chown", "count", "copy", "get", "put", "mkdir", "list", "listdir", "getmerge", "isdir", "rename_dont_move", "touchz", ] def __init__(self, clients, method_names=None): self.clients = clients if method_names is None: method_names = self.ALL_METHOD_NAMES for method_name in method_names: new_method = self._make_method(method_name) real_method = MethodType(new_method, self) setattr(self, method_name, real_method) @classmethod def _make_method(cls, method_name): def new_method(self, *args, **kwargs): return self._chained_call(method_name, *args, **kwargs) return new_method def _chained_call(self, method_name, *args, **kwargs): for i in range(len(self.clients)): client = self.clients[i] try: result = getattr(client, method_name)(*args, **kwargs) return result except luigi.target.FileSystemException: # For exceptions that are semantical, we must throw along raise except BaseException: is_last_iteration = (i + 1) >= len(self.clients) if is_last_iteration: raise else: logger.warning( "The %s failed to %s, using fallback class %s", client.__class__.__name__, method_name, self.clients[i + 1].__class__.__name__ ) ================================================ FILE: luigi/contrib/webhdfs.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # """ Provides a :class:`WebHdfsTarget` using the `Python hdfs `_ This module is DEPRECATED and does not play well with rest of luigi's hdfs contrib module. You can consider migrating to :class:`luigi.contrib.hdfs.webhdfs_client.WebHdfsClient` """ import logging import luigi.contrib.hdfs from luigi.format import get_default_format from luigi.target import AtomicLocalFile, FileSystemTarget logger = logging.getLogger("luigi-interface") class WebHdfsTarget(FileSystemTarget): fs = None def __init__(self, path, client=None, format=None): super(WebHdfsTarget, self).__init__(path) path = self.path self.fs = client or WebHdfsClient() if format is None: format = get_default_format() self.format = format def open(self, mode="r"): if mode not in ("r", "w"): raise ValueError("Unsupported open mode '%s'" % mode) if mode == "r": return self.format.pipe_reader(ReadableWebHdfsFile(path=self.path, client=self.fs)) return self.format.pipe_writer(AtomicWebHdfsFile(path=self.path, client=self.fs)) class ReadableWebHdfsFile: def __init__(self, path, client): self.path = path self.client = client self.generator = None def read(self): self.generator = self.client.read(self.path) res = list(self.generator)[0] return res def readlines(self, char="\n"): self.generator = self.client.read(self.path, buffer_char=char) return self.generator def __enter__(self): return self def __exit__(self, exc_type, exc, traceback): self.close() def __iter__(self): self.generator = self.readlines("\n") yield from self.generator self.close() def close(self): self.generator.close() class AtomicWebHdfsFile(AtomicLocalFile): """ An Hdfs file that writes to a temp file and put to WebHdfs on close. 
""" def __init__(self, path, client): self.client = client super(AtomicWebHdfsFile, self).__init__(path) def move_to_final_destination(self): if not self.client.exists(self.path): self.client.upload(self.path, self.tmp_path) WebHdfsClient = luigi.contrib.hdfs.WebHdfsClient ================================================ FILE: luigi/date_interval.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ ``luigi.date_interval`` provides convenient classes for date algebra. Everything uses ISO 8601 notation, i.e. YYYY-MM-DD for dates, etc. There is a corresponding :class:`luigi.parameter.DateIntervalParameter` that you can use to parse date intervals. Example:: class MyTask(luigi.Task): date_interval = luigi.DateIntervalParameter() Now, you can launch this from the command line using ``--date-interval 2014-05-10`` or ``--date-interval 2014-W26`` (using week notation) or ``--date-interval 2014`` (for a year) and some other notations. """ import datetime import re class DateInterval: """ The :class:`DateInterval` is the base class with subclasses :class:`Date`, :class:`Week`, :class:`Month`, :class:`Year`, and :class:`Custom`. Note that the :class:`DateInterval` is abstract and should not be used directly: use :class:`Custom` for arbitrary date intervals. The base class features a couple of convenience methods, such as ``next()`` which returns the next consecutive date interval. 
Example:: x = luigi.date_interval.Week(2013, 52) print x.prev() This will print ``2014-W01``. All instances of :class:`DateInterval` have attributes ``date_a`` and ``date_b`` set. This represents the half open range of the date interval. For instance, a May 2014 is represented as ``date_a = 2014-05-01``, ``date_b = 2014-06-01``. """ def __init__(self, date_a, date_b): self.date_a = date_a self.date_b = date_b def dates(self): """Returns a list of dates in this date interval.""" dates = [] d = self.date_a while d < self.date_b: dates.append(d) d += datetime.timedelta(1) return dates def hours(self): """Same as dates() but returns 24 times more info: one for each hour.""" for date in self.dates(): for hour in range(24): yield datetime.datetime.combine(date, datetime.time(hour)) def __str__(self): return self.to_string() def __repr__(self): return self.to_string() def prev(self): """Returns the preceding corresponding date interval (eg. May -> April).""" return self.from_date(self.date_a - datetime.timedelta(1)) def next(self): """Returns the subsequent corresponding date interval (eg. 2014 -> 2015).""" return self.from_date(self.date_b) def to_string(self): raise NotImplementedError @classmethod def from_date(cls, d): """Abstract class method. For instance, ``Month.from_date(datetime.date(2012, 6, 6))`` returns a ``Month(2012, 6)``.""" raise NotImplementedError @classmethod def parse(cls, s): """Abstract class method. For instance, ``Year.parse("2014")`` returns a ``Year(2014)``.""" raise NotImplementedError def __contains__(self, date): return date in self.dates() def __iter__(self): for d in self.dates(): yield d def __hash__(self): return hash(repr(self)) def __cmp__(self, other): if not isinstance(self, type(other)): # doing this because it's not well defined if eg. 
2012-01-01-2013-01-01 == 2012 raise TypeError("Date interval type mismatch") return (self > other) - (self < other) def __lt__(self, other): if not isinstance(self, type(other)): raise TypeError("Date interval type mismatch") return (self.date_a, self.date_b) < (other.date_a, other.date_b) def __le__(self, other): if not isinstance(self, type(other)): raise TypeError("Date interval type mismatch") return (self.date_a, self.date_b) <= (other.date_a, other.date_b) def __gt__(self, other): if not isinstance(self, type(other)): raise TypeError("Date interval type mismatch") return (self.date_a, self.date_b) > (other.date_a, other.date_b) def __ge__(self, other): if not isinstance(self, type(other)): raise TypeError("Date interval type mismatch") return (self.date_a, self.date_b) >= (other.date_a, other.date_b) def __eq__(self, other): if not isinstance(other, DateInterval): return False if not isinstance(self, type(other)): raise TypeError("Date interval type mismatch") else: return (self.date_a, self.date_b) == (other.date_a, other.date_b) def __ne__(self, other): return not self.__eq__(other) class Date(DateInterval): """Most simple :class:`DateInterval` where ``date_b == date_a + datetime.timedelta(1)``.""" def __init__(self, y, m, d): a = datetime.date(y, m, d) b = datetime.date(y, m, d) + datetime.timedelta(1) super(Date, self).__init__(a, b) def to_string(self): return self.date_a.strftime("%Y-%m-%d") @classmethod def from_date(cls, d): return Date(d.year, d.month, d.day) @classmethod def parse(cls, s): if re.match(r"\d\d\d\d\-\d\d\-\d\d$", s): return Date(*map(int, s.split("-"))) class Week(DateInterval): """ISO 8601 week. Note that it has some counterintuitive behavior around new year. 
For instance Monday 29 December 2008 is week 2009-W01, and Sunday 3 January 2010 is week 2009-W53 This example was taken from from http://en.wikipedia.org/wiki/ISO_8601#Week_dates """ def __init__(self, y, w): """Python datetime does not have a method to convert from ISO weeks, so the constructor uses some stupid brute force""" for d in range(-10, 370): date = datetime.date(y, 1, 1) + datetime.timedelta(d) if date.isocalendar() == (y, w, 1): date_a = date break else: raise ValueError("Invalid week") date_b = date_a + datetime.timedelta(7) super(Week, self).__init__(date_a, date_b) def to_string(self): return "%d-W%02d" % self.date_a.isocalendar()[:2] @classmethod def from_date(cls, d): return Week(*d.isocalendar()[:2]) @classmethod def parse(cls, s): if re.match(r"\d\d\d\d\-W\d\d$", s): y, w = map(int, s.split("-W")) return Week(y, w) class Month(DateInterval): def __init__(self, y, m): date_a = datetime.date(y, m, 1) date_b = datetime.date(y + m // 12, 1 + m % 12, 1) super(Month, self).__init__(date_a, date_b) def to_string(self): return self.date_a.strftime("%Y-%m") @classmethod def from_date(cls, d): return Month(d.year, d.month) @classmethod def parse(cls, s): if re.match(r"\d\d\d\d\-\d\d$", s): y, m = map(int, s.split("-")) return Month(y, m) class Year(DateInterval): def __init__(self, y): date_a = datetime.date(y, 1, 1) date_b = datetime.date(y + 1, 1, 1) super(Year, self).__init__(date_a, date_b) def to_string(self): return self.date_a.strftime("%Y") @classmethod def from_date(cls, d): return Year(d.year) @classmethod def parse(cls, s): if re.match(r"\d\d\d\d$", s): return Year(int(s)) class Custom(DateInterval): """Custom date interval (does not implement prev and next methods) Actually the ISO 8601 specifies / as the time interval format Not sure if this goes for date intervals as well. In any case slashes will most likely cause problems with paths etc. 
""" def to_string(self): return "-".join([d.strftime("%Y-%m-%d") for d in (self.date_a, self.date_b)]) @classmethod def parse(cls, s): if re.match(r"\d\d\d\d\-\d\d\-\d\d\-\d\d\d\d\-\d\d\-\d\d$", s): x = list(map(int, s.split("-"))) date_a = datetime.date(*x[:3]) date_b = datetime.date(*x[3:]) return Custom(date_a, date_b) ================================================ FILE: luigi/db_task_history.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Provides a database backend to the central scheduler. This lets you see historical runs. See :ref:`TaskHistory` for information about how to turn out the task history feature. """ # # Description: Added codes for visualization of how long each task takes # running-time until it reaches the next status (failed or done) # At "{base_url}/tasklist", all completed(failed or done) tasks are shown. # At "{base_url}/tasklist", a user can select one specific task to see # how its running-time has changed over time. # At "{base_url}/tasklist/{task_name}", it visualizes a multi-bar graph # that represents the changes of the running-time for a selected task # up to the next status (failed or done). # This visualization let us know how the running-time of the specific task # has changed over time. # # Copyright 2015 Naver Corp. 
# Author Yeseul Park (yeseul.park@navercorp.com) # import datetime import logging from contextlib import contextmanager import sqlalchemy import sqlalchemy.ext.declarative import sqlalchemy.orm import sqlalchemy.orm.collections from sqlalchemy.engine import reflection from luigi import configuration, task_history from luigi.task_status import DONE, FAILED, PENDING, RUNNING Base = sqlalchemy.ext.declarative.declarative_base() logger = logging.getLogger("luigi-interface") class DbTaskHistory(task_history.TaskHistory): """ Task History that writes to a database using sqlalchemy. Also has methods for useful db queries. """ CURRENT_SOURCE_VERSION = 1 @contextmanager def _session(self, session=None): if session: yield session else: session = self.session_factory() try: yield session except BaseException: session.rollback() raise else: session.commit() def __init__(self): config = configuration.get_config() connection_string = config.get("task_history", "db_connection") self.engine = sqlalchemy.create_engine(connection_string) self.session_factory = sqlalchemy.orm.sessionmaker(bind=self.engine, expire_on_commit=False) Base.metadata.create_all(self.engine) self.tasks = {} # task_id -> TaskRecord _upgrade_schema(self.engine) def task_scheduled(self, task): htask = self._get_task(task, status=PENDING) self._add_task_event(htask, TaskEvent(event_name=PENDING, ts=datetime.datetime.now())) def task_finished(self, task, successful): event_name = DONE if successful else FAILED htask = self._get_task(task, status=event_name) self._add_task_event(htask, TaskEvent(event_name=event_name, ts=datetime.datetime.now())) def task_started(self, task, worker_host): htask = self._get_task(task, status=RUNNING, host=worker_host) self._add_task_event(htask, TaskEvent(event_name=RUNNING, ts=datetime.datetime.now())) def _get_task(self, task, status, host=None): if task.id in self.tasks: htask = self.tasks[task.id] htask.status = status if host: htask.host = host else: htask = 
self.tasks[task.id] = task_history.StoredTask(task, status, host) return htask def _add_task_event(self, task, event): for task_record, session in self._find_or_create_task(task): task_record.events.append(event) def _find_or_create_task(self, task): with self._session() as session: if task.record_id is not None: logger.debug("Finding task with record_id [%d]", task.record_id) task_record = session.query(TaskRecord).get(task.record_id) if not task_record: raise Exception("Task with record_id, but no matching Task record!") yield (task_record, session) else: task_record = TaskRecord(task_id=task._task.id, name=task.task_family, host=task.host) for k, v in task.parameters.items(): task_record.parameters[k] = TaskParameter(name=k, value=v) session.add(task_record) yield (task_record, session) if task.host: task_record.host = task.host task.record_id = task_record.id def find_all_by_parameters(self, task_name, session=None, **task_params): """ Find tasks with the given task_name and the same parameters as the kwargs. """ with self._session(session) as session: query = session.query(TaskRecord).join(TaskEvent).filter(TaskRecord.name == task_name) for k, v in task_params.items(): alias = sqlalchemy.orm.aliased(TaskParameter) query = query.join(alias).filter(alias.name == k, alias.value == v) tasks = query.order_by(TaskEvent.ts) for task in tasks: # Sanity check assert all(k in task.parameters and v == str(task.parameters[k].value) for k, v in task_params.items()) yield task def find_all_by_name(self, task_name, session=None): """ Find all tasks with the given task_name. """ return self.find_all_by_parameters(task_name, session) def find_latest_runs(self, session=None): """ Return tasks that have been updated in the past 24 hours. 
""" with self._session(session) as session: yesterday = datetime.datetime.now() - datetime.timedelta(days=1) return ( session.query(TaskRecord) .join(TaskEvent) .filter(TaskEvent.ts >= yesterday) .group_by(TaskRecord.id, TaskEvent.event_name, TaskEvent.ts) .order_by(TaskEvent.ts.desc()) .all() ) def find_all_runs(self, session=None): """ Return all tasks that have been updated. """ with self._session(session) as session: return session.query(TaskRecord).all() def find_all_events(self, session=None): """ Return all running/failed/done events. """ with self._session(session) as session: return session.query(TaskEvent).all() def find_task_by_id(self, id, session=None): """ Find task with the given record ID. """ with self._session(session) as session: return session.query(TaskRecord).get(id) def find_task_by_task_id(self, task_id, session=None): """ Find task with the given task ID. """ with self._session(session) as session: return session.query(TaskRecord).filter(TaskRecord.task_id == task_id).all()[-1] class TaskParameter(Base): # type: ignore """ Table to track luigi.Parameter()s of a Task. """ __tablename__ = "task_parameters" task_id = sqlalchemy.Column(sqlalchemy.Integer, sqlalchemy.ForeignKey("tasks.id"), primary_key=True) name = sqlalchemy.Column(sqlalchemy.String(128), primary_key=True) value = sqlalchemy.Column(sqlalchemy.Text()) def __repr__(self): return "TaskParameter(task_id=%d, name=%s, value=%s)" % (self.task_id, self.name, self.value) class TaskEvent(Base): # type: ignore """ Table to track when a task is scheduled, starts, finishes, and fails. 
""" __tablename__ = "task_events" id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) task_id = sqlalchemy.Column(sqlalchemy.Integer, sqlalchemy.ForeignKey("tasks.id"), index=True) event_name = sqlalchemy.Column(sqlalchemy.String(20)) ts = sqlalchemy.Column(sqlalchemy.TIMESTAMP, index=True, nullable=False) def __repr__(self): return "TaskEvent(task_id=%s, event_name=%s, ts=%s" % (self.task_id, self.event_name, self.ts) class TaskRecord(Base): # type: ignore """ Base table to track information about a luigi.Task. References to other tables are available through task.events, task.parameters, etc. """ __tablename__ = "tasks" id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) task_id = sqlalchemy.Column(sqlalchemy.String(200), index=True) name = sqlalchemy.Column(sqlalchemy.String(128), index=True) host = sqlalchemy.Column(sqlalchemy.String(128)) parameters = sqlalchemy.orm.relationship( "TaskParameter", collection_class=sqlalchemy.orm.collections.attribute_mapped_collection("name"), cascade="all, delete-orphan" ) events = sqlalchemy.orm.relationship("TaskEvent", order_by=(sqlalchemy.desc(TaskEvent.ts), sqlalchemy.desc(TaskEvent.id)), backref="task") def __repr__(self): return "TaskRecord(name=%s, host=%s)" % (self.name, self.host) def _upgrade_schema(engine): """ Ensure the database schema is up to date with the codebase. :param engine: SQLAlchemy engine of the underlying database. """ inspector = reflection.Inspector.from_engine(engine) with engine.connect() as conn: # Upgrade 1. Add task_id column and index to tasks if "task_id" not in [x["name"] for x in inspector.get_columns("tasks")]: logger.warning("Upgrading DbTaskHistory schema: Adding tasks.task_id") conn.execute("ALTER TABLE tasks ADD COLUMN task_id VARCHAR(200)") conn.execute("CREATE INDEX ix_task_id ON tasks (task_id)") # Upgrade 2. 
Alter value column to be TEXT, note that this is idempotent so no if-guard if "mysql" in engine.dialect.name: conn.execute("ALTER TABLE task_parameters MODIFY COLUMN value TEXT") elif "oracle" in engine.dialect.name: conn.execute("ALTER TABLE task_parameters MODIFY value TEXT") elif "mssql" in engine.dialect.name: conn.execute("ALTER TABLE task_parameters ALTER COLUMN value TEXT") elif "postgresql" in engine.dialect.name: if str([x for x in inspector.get_columns("task_parameters") if x["name"] == "value"][0]["type"]) != "TEXT": conn.execute("ALTER TABLE task_parameters ALTER COLUMN value TYPE TEXT") elif "sqlite" in engine.dialect.name: # SQLite does not support changing column types. A database file will need # to be used to pickup this migration change. for i in conn.execute("PRAGMA table_info(task_parameters);").fetchall(): if i["name"] == "value" and i["type"] != "TEXT": logger.warning("SQLite can not change column types. Please use a new database to pickup column type changes.") else: logger.warning("SQLAlcheny dialect {} could not be migrated to the TEXT type".format(engine.dialect)) ================================================ FILE: luigi/event.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Definitions needed for events. See :ref:`Events` for info on how to use it.""" class Event: # TODO nice descriptive subclasses of Event instead of strings? 
pass their instances to the callback instead of an undocumented arg list? DEPENDENCY_DISCOVERED = "event.core.dependency.discovered" # triggered for every (task, upstream task) pair discovered in a jobflow DEPENDENCY_MISSING = "event.core.dependency.missing" DEPENDENCY_PRESENT = "event.core.dependency.present" BROKEN_TASK = "event.core.task.broken" START = "event.core.start" #: This event can be fired by the task itself while running. The purpose is #: for the task to report progress, metadata or any generic info so that #: event handler listening for this can keep track of the progress of running task. PROGRESS = "event.core.progress" FAILURE = "event.core.failure" SUCCESS = "event.core.success" PROCESSING_TIME = "event.core.processing_time" TIMEOUT = "event.core.timeout" # triggered if a task times out PROCESS_FAILURE = "event.core.process_failure" # triggered if the process a task is running in dies unexpectedly ================================================ FILE: luigi/execution_summary.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ This module provide the function :py:func:`summary` that is used for printing an `execution summary `_ at the end of luigi invocations. 
""" import collections import enum import functools import textwrap from luigi.parameter import IntParameter from luigi.task import Config class execution_summary(Config): summary_length = IntParameter(default=5) class LuigiStatusCode(enum.Enum): """ All possible status codes for the attribute ``status`` in :class:`~luigi.execution_summary.LuigiRunResult` when the argument ``detailed_summary=True`` in *luigi.run() / luigi.build*. Here are the codes and what they mean: ============================= ========================================================== Status Code Name Meaning ============================= ========================================================== SUCCESS There were no failed tasks or missing dependencies SUCCESS_WITH_RETRY There were failed tasks but they all succeeded in a retry FAILED There were failed tasks FAILED_AND_SCHEDULING_FAILED There were failed tasks and tasks whose scheduling failed SCHEDULING_FAILED There were tasks whose scheduling failed NOT_RUN There were tasks that were not granted run permission by the scheduler MISSING_EXT There were missing external dependencies ============================= ========================================================== """ SUCCESS = (":)", "there were no failed tasks or missing dependencies") SUCCESS_WITH_RETRY = (":)", "there were failed tasks but they all succeeded in a retry") FAILED = (":(", "there were failed tasks") FAILED_AND_SCHEDULING_FAILED = (":(", "there were failed tasks and tasks whose scheduling failed") SCHEDULING_FAILED = (":(", "there were tasks whose scheduling failed") NOT_RUN = (":|", "there were tasks that were not granted run permission by the scheduler") MISSING_EXT = (":|", "there were missing external dependencies") class LuigiRunResult: """ The result of a call to build/run when passing the detailed_summary=True argument. Attributes: - one_line_summary (str): One line summary of the progress. - summary_text (str): Detailed summary of the progress. 
- status (LuigiStatusCode): Luigi Status Code. See :class:`~luigi.execution_summary.LuigiStatusCode` for what these codes mean. - worker (luigi.worker.worker): Worker object. See :class:`~luigi.worker.worker`. - scheduling_succeeded (bool): Boolean which is *True* if all the tasks were scheduled without errors. """ def __init__(self, worker, worker_add_run_status=True): self.worker = worker summary_dict = _summary_dict(worker) self.summary_text = _summary_wrap(_summary_format(summary_dict, worker)) self.status = _tasks_status(summary_dict) self.one_line_summary = _create_one_line_summary(self.status) self.scheduling_succeeded = worker_add_run_status def __str__(self): return "LuigiRunResult with status {0}".format(self.status) def __repr__(self): return "LuigiRunResult(status={0!r},worker={1!r},scheduling_succeeded={2!r})".format(self.status, self.worker, self.scheduling_succeeded) def _partition_tasks(worker): """ Takes a worker and sorts out tasks based on their status. Still_pending_not_ext is only used to get upstream_failure, upstream_missing_dependency and run_by_other_worker """ task_history = worker._add_task_history pending_tasks = {task for (task, status, ext) in task_history if status == "PENDING"} set_tasks = {} set_tasks["completed"] = {task for (task, status, ext) in task_history if status == "DONE" and task in pending_tasks} set_tasks["already_done"] = { task for (task, status, ext) in task_history if status == "DONE" and task not in pending_tasks and task not in set_tasks["completed"] } set_tasks["ever_failed"] = {task for (task, status, ext) in task_history if status == "FAILED"} set_tasks["failed"] = set_tasks["ever_failed"] - set_tasks["completed"] set_tasks["scheduling_error"] = {task for (task, status, ext) in task_history if status == "UNKNOWN"} set_tasks["still_pending_ext"] = { task for (task, status, ext) in task_history if status == "PENDING" and task not in set_tasks["ever_failed"] and task not in set_tasks["completed"] and not ext } 
set_tasks["still_pending_not_ext"] = { task for (task, status, ext) in task_history if status == "PENDING" and task not in set_tasks["ever_failed"] and task not in set_tasks["completed"] and ext } set_tasks["run_by_other_worker"] = set() set_tasks["upstream_failure"] = set() set_tasks["upstream_missing_dependency"] = set() set_tasks["upstream_run_by_other_worker"] = set() set_tasks["upstream_scheduling_error"] = set() set_tasks["not_run"] = set() return set_tasks def _root_task(worker): """ Return the first task scheduled by the worker, corresponding to the root task """ return worker._add_task_history[0][0] def _populate_unknown_statuses(set_tasks): """ Add the "upstream_*" and "not_run" statuses my mutating set_tasks. """ visited = set() for task in set_tasks["still_pending_not_ext"]: _depth_first_search(set_tasks, task, visited) def _depth_first_search(set_tasks, current_task, visited): """ This dfs checks why tasks are still pending. """ visited.add(current_task) if current_task in set_tasks["still_pending_not_ext"]: upstream_failure = False upstream_missing_dependency = False upstream_run_by_other_worker = False upstream_scheduling_error = False for task in current_task._requires(): if task not in visited: _depth_first_search(set_tasks, task, visited) if task in set_tasks["ever_failed"] or task in set_tasks["upstream_failure"]: set_tasks["upstream_failure"].add(current_task) upstream_failure = True if task in set_tasks["still_pending_ext"] or task in set_tasks["upstream_missing_dependency"]: set_tasks["upstream_missing_dependency"].add(current_task) upstream_missing_dependency = True if task in set_tasks["run_by_other_worker"] or task in set_tasks["upstream_run_by_other_worker"]: set_tasks["upstream_run_by_other_worker"].add(current_task) upstream_run_by_other_worker = True if task in set_tasks["scheduling_error"]: set_tasks["upstream_scheduling_error"].add(current_task) upstream_scheduling_error = True if ( not upstream_failure and not 
upstream_missing_dependency and not upstream_run_by_other_worker and not upstream_scheduling_error and current_task not in set_tasks["run_by_other_worker"] ): set_tasks["not_run"].add(current_task) def _get_str(task_dict, extra_indent): """ This returns a string for each status """ summary_length = execution_summary().summary_length lines = [] task_names = sorted(task_dict.keys()) for task_family in task_names: tasks = task_dict[task_family] tasks = sorted(tasks, key=lambda x: str(x)) prefix_size = 8 if extra_indent else 4 prefix = " " * prefix_size line = None if summary_length > 0 and len(lines) >= summary_length: line = prefix + "..." lines.append(line) break if len(tasks[0].get_params()) == 0: line = prefix + "- {0} {1}()".format(len(tasks), str(task_family)) elif ( _get_len_of_params(tasks[0]) > 60 or len(str(tasks[0])) > 200 or (len(tasks) == 2 and len(tasks[0].get_params()) > 1 and (_get_len_of_params(tasks[0]) > 40 or len(str(tasks[0])) > 100)) ): """ This is to make sure that there is no really long task in the output """ line = prefix + "- {0} {1}(...)".format(len(tasks), task_family) elif len((tasks[0].get_params())) == 1: attributes = {getattr(task, tasks[0].get_params()[0][0]) for task in tasks} param_class = tasks[0].get_params()[0][1] first, last = _ranging_attributes(attributes, param_class) if first is not None and last is not None and len(attributes) > 3: param_str = "{0}...{1}".format(param_class.serialize(first), param_class.serialize(last)) else: param_str = "{0}".format(_get_str_one_parameter(tasks)) line = prefix + "- {0} {1}({2}={3})".format(len(tasks), task_family, tasks[0].get_params()[0][0], param_str) else: ranging = False params = _get_set_of_params(tasks) unique_param_keys = list(_get_unique_param_keys(params)) if len(unique_param_keys) == 1: (unique_param,) = unique_param_keys attributes = params[unique_param] param_class = unique_param[1] first, last = _ranging_attributes(attributes, param_class) if first is not None and last is not 
None and len(attributes) > 2: ranging = True line = prefix + "- {0} {1}({2}".format(len(tasks), task_family, _get_str_ranging_multiple_parameters(first, last, tasks, unique_param)) if not ranging: if len(tasks) == 1: line = prefix + "- {0} {1}".format(len(tasks), tasks[0]) if len(tasks) == 2: line = prefix + "- {0} {1} and {2}".format(len(tasks), tasks[0], tasks[1]) if len(tasks) > 2: line = prefix + "- {0} {1} ...".format(len(tasks), tasks[0]) lines.append(line) return "\n".join(lines) def _get_len_of_params(task): return sum(len(param[0]) for param in task.get_params()) def _get_str_ranging_multiple_parameters(first, last, tasks, unique_param): row = "" str_unique_param = "{0}...{1}".format(unique_param[1].serialize(first), unique_param[1].serialize(last)) for param in tasks[0].get_params(): row += "{0}=".format(param[0]) if param[0] == unique_param[0]: row += "{0}".format(str_unique_param) else: row += "{0}".format(param[1].serialize(getattr(tasks[0], param[0]))) if param != tasks[0].get_params()[-1]: row += ", " row += ")" return row def _get_set_of_params(tasks): params = {} for param in tasks[0].get_params(): params[param] = {getattr(task, param[0]) for task in tasks} return params def _get_unique_param_keys(params): for param_key, param_values in params.items(): if len(param_values) > 1: yield param_key def _ranging_attributes(attributes, param_class): """ Checks if there is a continuous range """ next_attributes = {param_class.next_in_enumeration(attribute) for attribute in attributes} in_first = attributes.difference(next_attributes) in_second = next_attributes.difference(attributes) if len(in_first) == 1 and len(in_second) == 1: for x in attributes: if {param_class.next_in_enumeration(x)} == in_second: return next(iter(in_first)), x return None, None def _get_str_one_parameter(tasks): row = "" count = 0 for task in tasks: if (len(row) >= 30 and count > 2 and count != len(tasks) - 1) or len(row) > 200: row += "..." 
break param = task.get_params()[0] row += "{0}".format(param[1].serialize(getattr(task, param[0]))) if count < len(tasks) - 1: row += "," count += 1 return row def _serialize_first_param(task): return task.get_params()[0][1].serialize(getattr(task, task.get_params()[0][0])) def _get_number_of_tasks_for(status, group_tasks): if status == "still_pending": return _get_number_of_tasks(group_tasks["still_pending_ext"]) + _get_number_of_tasks(group_tasks["still_pending_not_ext"]) return _get_number_of_tasks(group_tasks[status]) def _get_number_of_tasks(task_dict): return sum(len(tasks) for tasks in task_dict.values()) def _get_comments(group_tasks): """ Get the human readable comments and quantities for the task types. """ comments = {} for status, human in _COMMENTS: num_tasks = _get_number_of_tasks_for(status, group_tasks) if num_tasks: space = " " if status in _PENDING_SUB_STATUSES else "" comments[status] = "{space}* {num_tasks} {human}:\n".format(space=space, num_tasks=num_tasks, human=human) return comments # Oredered in the sense that they'll be printed in this order _ORDERED_STATUSES = ( "already_done", "completed", "ever_failed", "failed", "scheduling_error", "still_pending", "still_pending_ext", "run_by_other_worker", "upstream_failure", "upstream_missing_dependency", "upstream_run_by_other_worker", "upstream_scheduling_error", "not_run", ) _PENDING_SUB_STATUSES = set(_ORDERED_STATUSES[_ORDERED_STATUSES.index("still_pending_ext") :]) _COMMENTS = { ("already_done", "complete ones were encountered"), ("completed", "ran successfully"), ("failed", "failed"), ("scheduling_error", "failed scheduling"), ("still_pending", "were left pending, among these"), ("still_pending_ext", "were missing external dependencies"), ("run_by_other_worker", "were being run by another worker"), ("upstream_failure", "had failed dependencies"), ("upstream_missing_dependency", "had missing dependencies"), ("upstream_run_by_other_worker", "had dependencies that were being run by other 
worker"), ("upstream_scheduling_error", "had dependencies whose scheduling failed"), ("not_run", "was not granted run permission by the scheduler"), } def _get_run_by_other_worker(worker): """ This returns a set of the tasks that are being run by other worker """ task_sets = _get_external_workers(worker).values() return functools.reduce(lambda a, b: a | b, task_sets, set()) def _get_external_workers(worker): """ This returns a dict with a set of tasks for all of the other workers """ worker_that_blocked_task = collections.defaultdict(set) get_work_response_history = worker._get_work_response_history for get_work_response in get_work_response_history: if get_work_response["task_id"] is None: for running_task in get_work_response["running_tasks"]: other_worker_id = running_task["worker"] other_task_id = running_task["task_id"] other_task = worker._scheduled_tasks.get(other_task_id) if other_worker_id == worker._id or not other_task: continue worker_that_blocked_task[other_worker_id].add(other_task) return worker_that_blocked_task def _group_tasks_by_name_and_status(task_dict): """ Takes a dictionary with sets of tasks grouped by their status and returns a dictionary with dictionaries with an array of tasks grouped by their status and task name """ group_status = {} for task in task_dict: if task.task_family not in group_status: group_status[task.task_family] = [] group_status[task.task_family].append(task) return group_status def _summary_dict(worker): set_tasks = _partition_tasks(worker) set_tasks["run_by_other_worker"] = _get_run_by_other_worker(worker) _populate_unknown_statuses(set_tasks) return set_tasks def _summary_format(set_tasks, worker): group_tasks = {} for status, task_dict in set_tasks.items(): group_tasks[status] = _group_tasks_by_name_and_status(task_dict) comments = _get_comments(group_tasks) num_all_tasks = sum( [ len(set_tasks["already_done"]), len(set_tasks["completed"]), len(set_tasks["failed"]), len(set_tasks["scheduling_error"]), 
len(set_tasks["still_pending_ext"]), len(set_tasks["still_pending_not_ext"]), ] ) str_output = "" str_output += "Scheduled {0} tasks of which:\n".format(num_all_tasks) for status in _ORDERED_STATUSES: if status not in comments: continue str_output += "{0}".format(comments[status]) if status != "still_pending": str_output += "{0}\n".format(_get_str(group_tasks[status], status in _PENDING_SUB_STATUSES)) ext_workers = _get_external_workers(worker) group_tasks_ext_workers = {} for ext_worker, task_dict in ext_workers.items(): group_tasks_ext_workers[ext_worker] = _group_tasks_by_name_and_status(task_dict) if len(ext_workers) > 0: str_output += "\nThe other workers were:\n" count = 0 for ext_worker, task_dict in ext_workers.items(): if count > 3 and count < len(ext_workers) - 1: str_output += " and {0} other workers".format(len(ext_workers) - count) break str_output += " - {0} ran {1} tasks\n".format(ext_worker, len(task_dict)) count += 1 str_output += "\n" if num_all_tasks == sum( [len(set_tasks["already_done"]), len(set_tasks["scheduling_error"]), len(set_tasks["still_pending_ext"]), len(set_tasks["still_pending_not_ext"])] ): if len(ext_workers) == 0: str_output += "\n" str_output += "Did not run any tasks" one_line_summary = _create_one_line_summary(_tasks_status(set_tasks)) str_output += "\n{0}".format(one_line_summary) if num_all_tasks == 0: str_output = "Did not schedule any tasks" return str_output def _create_one_line_summary(status_code): """ Given a status_code of type LuigiStatusCode which has a tuple value, returns a one line summary """ return "This progress looks {0} because {1}".format(*status_code.value) def _tasks_status(set_tasks): """ Given a grouped set of tasks, returns a LuigiStatusCode """ if set_tasks["ever_failed"]: if not set_tasks["failed"]: return LuigiStatusCode.SUCCESS_WITH_RETRY else: if set_tasks["scheduling_error"]: return LuigiStatusCode.FAILED_AND_SCHEDULING_FAILED return LuigiStatusCode.FAILED elif set_tasks["scheduling_error"]: 
return LuigiStatusCode.SCHEDULING_FAILED elif set_tasks["not_run"]: return LuigiStatusCode.NOT_RUN elif set_tasks["still_pending_ext"]: return LuigiStatusCode.MISSING_EXT else: return LuigiStatusCode.SUCCESS def _summary_wrap(str_output): return textwrap.dedent(""" ===== Luigi Execution Summary ===== {str_output} ===== Luigi Execution Summary ===== """).format(str_output=str_output) def summary(worker): """ Given a worker, return a human readable summary of what the worker have done. """ return _summary_wrap(_summary_format(_summary_dict(worker), worker)) # 5 ================================================ FILE: luigi/format.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import io import locale import os import re import signal import subprocess import tempfile import warnings class FileWrapper: """ Wrap `file` in a "real" so stuff can be added to it after creation. 
""" def __init__(self, file_object): self._subpipe = file_object def __getattr__(self, name): # forward calls to 'write', 'close' and other methods not defined below return getattr(self._subpipe, name) def __enter__(self, *args, **kwargs): # instead of returning whatever is returned by __enter__ on the subpipe # this returns self, so whatever custom injected methods are still available # this might cause problems with custom file_objects, but seems to work # fine with standard python `file` objects which is the only default use return self def __exit__(self, *args, **kwargs): return self._subpipe.__exit__(*args, **kwargs) def __iter__(self): return iter(self._subpipe) class InputPipeProcessWrapper: def __init__(self, command, input_pipe=None): """ Initializes a InputPipeProcessWrapper instance. :param command: a subprocess.Popen instance with stdin=input_pipe and stdout=subprocess.PIPE. Alternatively, just its args argument as a convenience. """ self._command = command self._input_pipe = input_pipe self._original_input = True if input_pipe is not None: try: input_pipe.fileno() except (AttributeError, io.UnsupportedOperation): # subprocess require a fileno to work, if not present we copy to disk first self._original_input = False f = tempfile.NamedTemporaryFile("wb", prefix="luigi-process_tmp", delete=False) self._tmp_file = f.name while True: chunk = input_pipe.read(io.DEFAULT_BUFFER_SIZE) if not chunk: break f.write(chunk) input_pipe.close() f.close() self._input_pipe = FileWrapper(io.BufferedReader(io.FileIO(self._tmp_file, "r"))) self._process = command if isinstance(command, subprocess.Popen) else self.create_subprocess(command) # we want to keep a circular reference to avoid garbage collection # when the object is used in, e.g., pipe.read() self._process._selfref = self def create_subprocess(self, command): """ http://www.chiark.greenend.org.uk/ucgi/~cjwatson/blosxom/2009-07-02-python-sigpipe.html """ def subprocess_setup(): # Python installs a SIGPIPE handler 
by default. This is usually not what # non-Python subprocesses expect. signal.signal(signal.SIGPIPE, signal.SIG_DFL) return subprocess.Popen(command, stdin=self._input_pipe, stdout=subprocess.PIPE, preexec_fn=subprocess_setup, close_fds=True) def _finish(self): # Need to close this before input_pipe to get all SIGPIPE messages correctly self._process.stdout.close() if not self._original_input and os.path.exists(self._tmp_file): os.remove(self._tmp_file) if self._input_pipe is not None: self._input_pipe.close() self._process.wait() # deadlock? if self._process.returncode not in (0, 141, 128 - 141): # 141 == 128 + 13 == 128 + SIGPIPE - normally processes exit with 128 + {reiceived SIG} # 128 - 141 == -13 == -SIGPIPE, sometimes python receives -13 for some subprocesses raise RuntimeError("Error reading from pipe. Subcommand exited with non-zero exit status %s." % self._process.returncode) def close(self): self._finish() def __del__(self): self._finish() def __enter__(self): return self def _abort(self): """ Call _finish, but eat the exception (if any). 
""" try: self._finish() except KeyboardInterrupt: raise except BaseException: pass def __exit__(self, type, value, traceback): if type: self._abort() else: self._finish() def __getattr__(self, name): if name in ["_process", "_input_pipe"]: raise AttributeError(name) try: return getattr(self._process.stdout, name) except AttributeError: return getattr(self._input_pipe, name) def __iter__(self): for line in self._process.stdout: yield line self._finish() def readable(self): return True def writable(self): return False def seekable(self): return False class OutputPipeProcessWrapper: WRITES_BEFORE_FLUSH = 10000 def __init__(self, command, output_pipe=None): self.closed = False self._command = command self._output_pipe = output_pipe self._process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=output_pipe, close_fds=True) self._flushcount = 0 def write(self, *args, **kwargs): self._process.stdin.write(*args, **kwargs) self._flushcount += 1 if self._flushcount == self.WRITES_BEFORE_FLUSH: self._process.stdin.flush() self._flushcount = 0 def writeLine(self, line): assert "\n" not in line self.write(line + "\n") def _finish(self): """ Closes and waits for subprocess to exit. 
""" if self._process.returncode is None: self._process.stdin.flush() self._process.stdin.close() self._process.wait() self.closed = True def __del__(self): if not self.closed: self.abort() def __exit__(self, type, value, traceback): if type is None: self.close() else: self.abort() def __enter__(self): return self def close(self): self._finish() if self._process.returncode == 0: if self._output_pipe is not None: self._output_pipe.close() else: raise RuntimeError("Error when executing command %s" % self._command) def abort(self): self._finish() def __getattr__(self, name): if name in ["_process", "_output_pipe"]: raise AttributeError(name) try: return getattr(self._process.stdin, name) except AttributeError: return getattr(self._output_pipe, name) def readable(self): return False def writable(self): return True def seekable(self): return False class BaseWrapper: def __init__(self, stream, *args, **kwargs): self._stream = stream try: super(BaseWrapper, self).__init__(stream, *args, **kwargs) except TypeError: pass def __getattr__(self, name): if name == "_stream": raise AttributeError(name) return getattr(self._stream, name) def __enter__(self): self._stream.__enter__() return self def __exit__(self, *args): self._stream.__exit__(*args) def __iter__(self): try: for line in self._stream: yield line finally: self.close() class NewlineWrapper(BaseWrapper): def __init__(self, stream, newline=None): if newline is None: self.newline = newline else: self.newline = newline.encode("ascii") if self.newline not in (b"", b"\r\n", b"\n", b"\r", None): raise ValueError("newline need to be one of {b'', b'\r\n', b'\n', b'\r', None}") super(NewlineWrapper, self).__init__(stream) def read(self, n=-1): b = self._stream.read(n) if self.newline == b"": return b if self.newline is None: newline = b"\n" return re.sub(b"(\n|\r\n|\r)", newline, b) def writelines(self, lines): if self.newline is None or self.newline == "": newline = os.linesep.encode("ascii") else: newline = self.newline 
self._stream.writelines((re.sub(b"(\n|\r\n|\r)", newline, line) for line in lines)) def write(self, b): if self.newline is None or self.newline == "": newline = os.linesep.encode("ascii") else: newline = self.newline self._stream.write(re.sub(b"(\n|\r\n|\r)", newline, b)) class MixedUnicodeBytesWrapper(BaseWrapper): """ """ def __init__(self, stream, encoding=None): if encoding is None: encoding = locale.getpreferredencoding() self.encoding = encoding super(MixedUnicodeBytesWrapper, self).__init__(stream) def write(self, b): self._stream.write(self._convert(b)) def writelines(self, lines): self._stream.writelines((self._convert(line) for line in lines)) def _convert(self, b): if isinstance(b, str): b = b.encode(self.encoding) warnings.warn("Writing unicode to byte stream", stacklevel=2) return b class Format: """ Interface for format specifications. """ @classmethod def pipe_reader(cls, input_pipe): raise NotImplementedError() @classmethod def pipe_writer(cls, output_pipe): raise NotImplementedError() def __rshift__(a, b): return ChainFormat(a, b) class ChainFormat(Format): def __init__(self, *args, **kwargs): self.args = args try: self.input = args[0].input except AttributeError: pass try: self.output = args[-1].output except AttributeError: pass if not kwargs.get("check_consistency", True): return for x in range(len(args) - 1): try: if args[x].output != args[x + 1].input: raise TypeError( "The format chaining is not valid, %s expect %s" " but %s provide %s" % ( args[x + 1].__class__.__name__, args[x + 1].input, args[x].__class__.__name__, args[x].output, ) ) except AttributeError: pass def pipe_reader(self, input_pipe): for x in reversed(self.args): input_pipe = x.pipe_reader(input_pipe) return input_pipe def pipe_writer(self, output_pipe): for x in reversed(self.args): output_pipe = x.pipe_writer(output_pipe) return output_pipe class TextWrapper(io.TextIOWrapper): def __exit__(self, *args): # io.TextIOWrapper close the file on __exit__, let the underlying file 
decide if not self.closed and self.writable(): super(TextWrapper, self).flush() self._stream.__exit__(*args) def __del__(self, *args): # io.TextIOWrapper close the file on __del__, let the underlying file decide if not self.closed and self.writable(): super(TextWrapper, self).flush() try: self._stream.__del__(*args) except AttributeError: pass def __init__(self, stream, *args, **kwargs): self._stream = stream try: super(TextWrapper, self).__init__(stream, *args, **kwargs) except TypeError: pass def __getattr__(self, name): if name == "_stream": raise AttributeError(name) return getattr(self._stream, name) def __enter__(self): self._stream.__enter__() return self class NopFormat(Format): def pipe_reader(self, input_pipe): return input_pipe def pipe_writer(self, output_pipe): return output_pipe class WrappedFormat(Format): def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs def pipe_reader(self, input_pipe): return self.wrapper_cls(input_pipe, *self.args, **self.kwargs) def pipe_writer(self, output_pipe): return self.wrapper_cls(output_pipe, *self.args, **self.kwargs) class TextFormat(WrappedFormat): input = "unicode" output = "bytes" wrapper_cls = TextWrapper class MixedUnicodeBytesFormat(WrappedFormat): output = "bytes" wrapper_cls = MixedUnicodeBytesWrapper class NewlineFormat(WrappedFormat): input = "bytes" output = "bytes" wrapper_cls = NewlineWrapper class GzipFormat(Format): input = "bytes" output = "bytes" def __init__(self, compression_level=None): self.compression_level = compression_level def pipe_reader(self, input_pipe): return InputPipeProcessWrapper(["gunzip"], input_pipe) def pipe_writer(self, output_pipe): args = ["gzip"] if self.compression_level is not None: args.append("-" + str(int(self.compression_level))) return OutputPipeProcessWrapper(args, output_pipe) class Bzip2Format(Format): input = "bytes" output = "bytes" def pipe_reader(self, input_pipe): return InputPipeProcessWrapper(["bzcat"], input_pipe) def 
pipe_writer(self, output_pipe): return OutputPipeProcessWrapper(["bzip2"], output_pipe) Text = TextFormat() UTF8 = TextFormat(encoding="utf8") Nop = NopFormat() SysNewLine = NewlineFormat() Gzip = GzipFormat() Bzip2 = Bzip2Format() MixedUnicodeBytes = MixedUnicodeBytesFormat() def get_default_format(): return Text ================================================ FILE: luigi/freezing.py ================================================ """Internal-only module with immutable data structures. Please, do not use it outside of Luigi codebase itself. """ from collections import OrderedDict try: from collections.abc import Mapping except ImportError: from collections import Mapping # type: ignore import functools import operator class FrozenOrderedDict(Mapping): """ It is an immutable wrapper around ordered dictionaries that implements the complete :py:class:`collections.Mapping` interface. It can be used as a drop-in replacement for dictionaries where immutability and ordering are desired. """ def __init__(self, *args, **kwargs): self.__dict = OrderedDict(*args, **kwargs) self.__hash = None def __getitem__(self, key): return self.__dict[key] def __iter__(self): return iter(self.__dict) def __len__(self): return len(self.__dict) def __repr__(self): # We should use short representation for beautiful console output return repr(dict(self.__dict)) def __hash__(self): if self.__hash is None: hashes = map(hash, self.items()) self.__hash = functools.reduce(operator.xor, hashes, 0) return self.__hash def get_wrapped(self): return self.__dict def recursively_freeze(value): """ Recursively walks ``Mapping``s and ``list``s and converts them to ``FrozenOrderedDict`` and ``tuples``, respectively. 
""" if isinstance(value, Mapping): return FrozenOrderedDict(((k, recursively_freeze(v)) for k, v in value.items())) elif isinstance(value, list) or isinstance(value, tuple): return tuple(recursively_freeze(v) for v in value) return value def recursively_unfreeze(value): """ Recursively walks ``FrozenOrderedDict``s and ``tuple``s and converts them to ``dict`` and ``list``, respectively. """ if isinstance(value, Mapping): return dict(((k, recursively_unfreeze(v)) for k, v in value.items())) elif isinstance(value, list) or isinstance(value, tuple): return list(recursively_unfreeze(v) for v in value) return value ================================================ FILE: luigi/interface.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ This module contains the bindings for command line integration and dynamic loading of tasks If you don't want to run luigi from the command line. You may use the methods defined in this module to programmatically run luigi. """ import logging import os import signal import sys import tempfile import warnings from luigi import lock, parameter, rpc, scheduler, task, worker from luigi.cmdline_parser import CmdlineParser from luigi.execution_summary import LuigiRunResult from luigi.setup_logging import InterfaceLogging class core(task.Config): """Keeps track of a bunch of environment params. Uses the internal luigi parameter mechanism. 
The nice thing is that we can instantiate this class and get an object with all the environment variables set. This is arguably a bit of a hack. """ use_cmdline_section = False ignore_unconsumed = { "autoload_range", "no_configure_logging", } local_scheduler = parameter.BoolParameter(default=False, description="Use an in-memory central scheduler. Useful for testing.", always_in_help=True) scheduler_host = parameter.Parameter( default="localhost", description="Hostname of machine running remote scheduler", config_path=dict(section="core", name="default-scheduler-host") ) scheduler_port = parameter.IntParameter( default=8082, description="Port of remote scheduler api process", config_path=dict(section="core", name="default-scheduler-port") ) scheduler_url = parameter.Parameter( default="", description="Full path to remote scheduler", config_path=dict(section="core", name="default-scheduler-url"), ) lock_size = parameter.IntParameter(default=1, description="Maximum number of workers running the same command") no_lock = parameter.BoolParameter(default=False, description="Ignore if similar process is already running") lock_pid_dir = parameter.Parameter(default=os.path.join(tempfile.gettempdir(), "luigi"), description="Directory to store the pid file") take_lock = parameter.BoolParameter(default=False, description="Signal other processes to stop getting work if already running") workers = parameter.IntParameter(default=1, description="Maximum number of parallel tasks to run") logging_conf_file = parameter.Parameter(default="", description="Configuration file for logging") log_level = parameter.ChoiceParameter( default="DEBUG", choices=["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], description="Default log level to use when logging_conf_file is not set", ) module = parameter.Parameter(default="", description="Used for dynamic loading of modules", always_in_help=True) parallel_scheduling = parameter.BoolParameter(default=False, description="Use 
multiprocessing to do scheduling in parallel.") parallel_scheduling_processes = parameter.IntParameter( default=0, description="The number of processes to use for scheduling in parallel. By default the number of available CPUs will be used" ) assistant = parameter.BoolParameter(default=False, description="Run any task from the scheduler.") help = parameter.BoolParameter(default=False, description="Show most common flags and all task-specific flags", always_in_help=True) help_all = parameter.BoolParameter(default=False, description="Show all command line flags", always_in_help=True) class _WorkerSchedulerFactory: def create_local_scheduler(self): return scheduler.Scheduler(prune_on_get_work=True, record_task_history=False) def create_remote_scheduler(self, url): return rpc.RemoteScheduler(url) def create_worker(self, scheduler, worker_processes, assistant=False): return worker.Worker(scheduler=scheduler, worker_processes=worker_processes, assistant=assistant) def _schedule_and_run(tasks, worker_scheduler_factory=None, override_defaults=None): """ :param tasks: :param worker_scheduler_factory: :param override_defaults: :return: True if all tasks and their dependencies were successfully run (or already completed); False if any error occurred. It will return a detailed response of type LuigiRunResult instead of a boolean if detailed_summary=True. 
""" if worker_scheduler_factory is None: worker_scheduler_factory = _WorkerSchedulerFactory() if override_defaults is None: override_defaults = {} env_params = core(**override_defaults) InterfaceLogging.setup(env_params) kill_signal = signal.SIGUSR1 if env_params.take_lock else None if not env_params.no_lock and not (lock.acquire_for(env_params.lock_pid_dir, env_params.lock_size, kill_signal)): raise PidLockAlreadyTakenExit() if env_params.local_scheduler: sch = worker_scheduler_factory.create_local_scheduler() else: if env_params.scheduler_url != "": url = env_params.scheduler_url else: url = "http://{host}:{port:d}/".format( host=env_params.scheduler_host, port=env_params.scheduler_port, ) sch = worker_scheduler_factory.create_remote_scheduler(url=url) worker = worker_scheduler_factory.create_worker(scheduler=sch, worker_processes=env_params.workers, assistant=env_params.assistant) success = True logger = logging.getLogger("luigi-interface") with worker: for t in tasks: success &= worker.add(t, env_params.parallel_scheduling, env_params.parallel_scheduling_processes) logger.info("Done scheduling tasks") success &= worker.run() luigi_run_result = LuigiRunResult(worker, success) logger.info(luigi_run_result.summary_text) if hasattr(sch, "close"): sch.close() return luigi_run_result class PidLockAlreadyTakenExit(SystemExit): """ The exception thrown by :py:func:`luigi.run`, when the lock file is inaccessible """ pass def run(*args, **kwargs): """ Please dont use. Instead use `luigi` binary. Run from cmdline using argparse. 
:param use_dynamic_argparse: Deprecated and ignored """ luigi_run_result = _run(*args, **kwargs) return luigi_run_result if kwargs.get("detailed_summary") else luigi_run_result.scheduling_succeeded def _run(cmdline_args=None, main_task_cls=None, worker_scheduler_factory=None, use_dynamic_argparse=None, local_scheduler=False, detailed_summary=False): if use_dynamic_argparse is not None: warnings.warn("use_dynamic_argparse is deprecated, don't set it.", DeprecationWarning, stacklevel=2) if cmdline_args is None: cmdline_args = sys.argv[1:] if main_task_cls: cmdline_args.insert(0, main_task_cls.task_family) if local_scheduler: cmdline_args.append("--local-scheduler") with CmdlineParser.global_instance(cmdline_args) as cp: return _schedule_and_run([cp.get_task_obj()], worker_scheduler_factory) def build(tasks, worker_scheduler_factory=None, detailed_summary=False, **env_params): """ Run internally, bypassing the cmdline parsing. Useful if you have some luigi code that you want to run internally. Example: .. code-block:: python luigi.build([MyTask1(), MyTask2()], local_scheduler=True) One notable difference is that `build` defaults to not using the identical process lock. Otherwise, `build` would only be callable once from each process. :param tasks: :param worker_scheduler_factory: :param env_params: :return: True if there were no scheduling errors, even if tasks may fail. """ if "no_lock" not in env_params: env_params["no_lock"] = True luigi_run_result = _schedule_and_run(tasks, worker_scheduler_factory, override_defaults=env_params) return luigi_run_result if detailed_summary else luigi_run_result.scheduling_succeeded ================================================ FILE: luigi/local_target.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ :class:`LocalTarget` provides a concrete implementation of a :py:class:`~luigi.target.Target` class that uses files on the local file system """ import errno import io import os import random import shutil import tempfile import warnings from luigi.format import FileWrapper, get_default_format from luigi.target import AtomicLocalFile, FileAlreadyExists, FileSystem, FileSystemTarget, MissingParentDirectory, NotADirectory class atomic_file(AtomicLocalFile): """Simple class that writes to a temp file and moves it on close() Also cleans up the temp file if close is not invoked """ def move_to_final_destination(self): os.replace(self.tmp_path, self.path) def generate_tmp_path(self, path): return path + "-luigi-tmp-%09d" % random.randrange(0, 10_000_000_000) class LocalFileSystem(FileSystem): """ Wrapper for access to file system operations. Work in progress - add things as needed. 
""" def copy(self, old_path, new_path, raise_if_exists=False): if raise_if_exists and os.path.exists(new_path): raise RuntimeError("Destination exists: %s" % new_path) d = os.path.dirname(new_path) if d and not os.path.exists(d): self.mkdir(d) shutil.copy(old_path, new_path) def exists(self, path): return os.path.exists(path) def mkdir(self, path, parents=True, raise_if_exists=False): if self.exists(path): if raise_if_exists: raise FileAlreadyExists() elif not self.isdir(path): raise NotADirectory() else: return if parents: try: os.makedirs(path) except OSError as err: # somebody already created the path if err.errno != errno.EEXIST: raise else: if not os.path.exists(os.path.dirname(path)): raise MissingParentDirectory() os.mkdir(path) def isdir(self, path): return os.path.isdir(path) def listdir(self, path): for dir_, _, files in os.walk(path): assert dir_.startswith(path) for name in files: yield os.path.join(dir_, name) def remove(self, path, recursive=True): if recursive and self.isdir(path): shutil.rmtree(path) else: os.remove(path) def move(self, old_path, new_path, raise_if_exists=False): """ Move file atomically. If source and destination are located on different filesystems, atomicity is approximated but cannot be guaranteed. """ if raise_if_exists and os.path.exists(new_path): raise FileAlreadyExists("Destination exists: %s" % new_path) d = os.path.dirname(new_path) if d and not os.path.exists(d): self.mkdir(d) try: os.replace(old_path, new_path) except OSError as err: if err.errno == errno.EXDEV: new_path_tmp = "%s-%09d" % (new_path, random.randint(0, 999999999)) shutil.copy(old_path, new_path_tmp) os.replace(new_path_tmp, new_path) os.remove(old_path) else: raise err def rename_dont_move(self, path, dest): """ Rename ``path`` to ``dest``, but don't move it into the ``dest`` folder (if it is a folder). This method is just a wrapper around the ``move`` method of LocalTarget. 
""" self.move(path, dest, raise_if_exists=True) class LocalTarget(FileSystemTarget): fs = LocalFileSystem() def __init__(self, path=None, format=None, is_tmp=False): if format is None: format = get_default_format() if not path: if not is_tmp: raise Exception("path or is_tmp must be set") path = os.path.join(tempfile.gettempdir(), "luigi-tmp-%09d" % random.randint(0, 999999999)) super(LocalTarget, self).__init__(path) self.format = format self.is_tmp = is_tmp def makedirs(self): """ Create all parent folders if they do not exist. """ normpath = os.path.normpath(self.path) parentfolder = os.path.dirname(normpath) if parentfolder: try: os.makedirs(parentfolder) except OSError: pass def open(self, mode="r"): rwmode = mode.replace("b", "").replace("t", "") if rwmode == "w": self.makedirs() return self.format.pipe_writer(atomic_file(self.path)) elif rwmode == "r": fileobj = FileWrapper(io.BufferedReader(io.FileIO(self.path, mode))) return self.format.pipe_reader(fileobj) else: raise Exception("mode must be 'r' or 'w' (got: %s)" % mode) def move(self, new_path, raise_if_exists=False): self.fs.move(self.path, new_path, raise_if_exists=raise_if_exists) def move_dir(self, new_path): self.move(new_path) def remove(self): self.fs.remove(self.path) def copy(self, new_path, raise_if_exists=False): self.fs.copy(self.path, new_path, raise_if_exists) @property def fn(self): warnings.warn("Use LocalTarget.path to reference filename", DeprecationWarning, stacklevel=2) return self.path def __del__(self): if hasattr(self, "is_tmp") and self.is_tmp and self.exists(): self.remove() ================================================ FILE: luigi/lock.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Locking functionality when launching things from the command line. Uses a pidfile. This prevents multiple identical workflows to be launched simultaneously. """ import errno import hashlib import os import sys from subprocess import PIPE, Popen def getpcmd(pid): """ Returns command of process. :param pid: """ if os.name == "nt": # Use wmic command instead of ps on Windows. cmd = "wmic path win32_process where ProcessID=%s get Commandline 2> nul" % (pid,) with os.popen(cmd, "r") as p: lines = [line for line in p.readlines() if line.strip("\r\n ") != ""] if lines: _, val = lines return val elif sys.platform == "darwin": # Use pgrep instead of /proc on macOS. pidfile = ".%d.pid" % (pid,) with open(pidfile, "w") as f: f.write(str(pid)) try: p = Popen(["pgrep", "-lf", "-F", pidfile], stdout=PIPE) stdout, _ = p.communicate() line = stdout.decode("utf8").strip() if line: _, scmd = line.split(" ", 1) return scmd finally: os.unlink(pidfile) else: # Use the /proc filesystem # At least on android there have been some issues with not all # process infos being readable. In these cases using the `ps` command # worked. See the pull request at # https://github.com/spotify/luigi/pull/1876 try: with open("/proc/{0}/cmdline".format(pid), "r") as fh: return fh.read().replace("\0", " ").rstrip() except IOError: # the system may not allow reading the command line # of a process owned by another user pass # Fallback instead of None, for e.g. 
Cygwin where -o is an "unknown option" for the ps command: return "[PROCESS_WITH_PID={}]".format(pid) def get_info(pid_dir, my_pid=None): # Check the name and pid of this process if my_pid is None: my_pid = os.getpid() my_cmd = getpcmd(my_pid) cmd_hash = my_cmd.encode("utf8") pid_file = os.path.join(pid_dir, hashlib.new("md5", cmd_hash, usedforsecurity=False).hexdigest()) + ".pid" return my_pid, my_cmd, pid_file def acquire_for(pid_dir, num_available=1, kill_signal=None): """ Makes sure the process is only run once at the same time with the same name. Notice that we since we check the process name, different parameters to the same command can spawn multiple processes at the same time, i.e. running "/usr/bin/my_process" does not prevent anyone from launching "/usr/bin/my_process --foo bar". """ my_pid, my_cmd, pid_file = get_info(pid_dir) # Create a pid file if it does not exist try: os.mkdir(pid_dir) os.chmod(pid_dir, 0o700) except OSError as exc: if exc.errno != errno.EEXIST: raise pass # Let variable "pids" be all pids who exist in the .pid-file who are still # about running the same command. pids = {pid for pid in _read_pids_file(pid_file) if getpcmd(pid) == my_cmd} if kill_signal is not None: for pid in pids: os.kill(pid, kill_signal) print("Sent kill signal to Pids: {}".format(pids)) # We allow for the killer to progress, yet we don't want these to stack # up! So we only allow it once. num_available += 1 if len(pids) >= num_available: # We are already running under a different pid print("Pid(s) {} already running".format(pids)) if kill_signal is not None: print('Note: There have (probably) been 1 other "--take-lock" process which continued to run! 
Probably no need to run this one as well.') return False _write_pids_file(pid_file, pids | {my_pid}) return True def _read_pids_file(pid_file): # First setup a python 2 vs 3 compatibility # http://stackoverflow.com/a/21368622/621449 try: FileNotFoundError # noqa: F823 except NameError: # Should only happen on python 2 FileNotFoundError = IOError # If the file happen to not exist, simply return # an empty set() try: with open(pid_file, "r") as f: return {int(pid_str.strip()) for pid_str in f if pid_str.strip()} except FileNotFoundError: return set() def _write_pids_file(pid_file, pids_set): with open(pid_file, "w") as f: f.writelines("{}\n".format(pid) for pid in pids_set) # Make the .pid-file writable by all (when the os allows for it) if os.name != "nt": s = os.stat(pid_file) if os.getuid() == s.st_uid: os.chmod(pid_file, s.st_mode | 0o777) ================================================ FILE: luigi/metrics.py ================================================ import abc import importlib from enum import Enum class MetricsCollectors(Enum): custom = -1 default = 1 none = 1 datadog = 2 prometheus = 3 @classmethod def get(cls, which, custom_import=None): if which == MetricsCollectors.none: return NoMetricsCollector() elif which == MetricsCollectors.datadog: from luigi.contrib.datadog_metric import DatadogMetricsCollector return DatadogMetricsCollector() elif which == MetricsCollectors.prometheus: from luigi.contrib.prometheus_metric import PrometheusMetricsCollector return PrometheusMetricsCollector() elif which == MetricsCollectors.custom: if custom_import is None: raise ValueError(f"MetricsCollectors value ' {which} ' is -1 and custom_import is None") split_import_string = custom_import.split(".") import_path = ".".join(split_import_string[:-1]) import_class_string = split_import_string[-1] mod = importlib.import_module(import_path) metrics_class = getattr(mod, import_class_string) if issubclass(metrics_class, MetricsCollector): return metrics_class() else: raise 
ValueError(f"Custom Import: {custom_import} is not a subclass of MetricsCollector") else: raise ValueError("MetricsCollectors value ' {0} ' isn't supported", which) class MetricsCollector(metaclass=abc.ABCMeta): """Abstractable MetricsCollector base class that can be replace by tool specific implementation. """ @abc.abstractmethod def __init__(self): pass @abc.abstractmethod def handle_task_started(self, task): pass @abc.abstractmethod def handle_task_failed(self, task): pass @abc.abstractmethod def handle_task_disabled(self, task, config): pass @abc.abstractmethod def handle_task_done(self, task): pass def handle_task_statistics(self, task, statistics): pass def generate_latest(self): return def configure_http_handler(self, http_handler): pass class NoMetricsCollector(MetricsCollector): """Empty MetricsCollector when no collector is being used""" def __init__(self): pass def handle_task_started(self, task): pass def handle_task_failed(self, task): pass def handle_task_disabled(self, task, config): pass def handle_task_done(self, task): pass ================================================ FILE: luigi/mock.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ This module provides a class :class:`MockTarget`, an implementation of :py:class:`~luigi.target.Target`. :class:`MockTarget` contains all data in-memory. 
The main purpose is unit testing workflows without writing to disk. """ import multiprocessing import sys from io import BytesIO from luigi import target from luigi.format import get_default_format class MockFileSystem(target.FileSystem): """ MockFileSystem inspects/modifies _data to simulate file system operations. """ _data = None def copy(self, path, dest, raise_if_exists=False): """ Copies the contents of a single file path to dest """ if raise_if_exists and dest in self.get_all_data(): raise RuntimeError("Destination exists: %s" % path) contents = self.get_all_data()[path] self.get_all_data()[dest] = contents def get_all_data(self): # This starts a server in the background, so we don't want to do it in the global scope if MockFileSystem._data is None: MockFileSystem._data = multiprocessing.Manager().dict() return MockFileSystem._data def get_data(self, fn): return self.get_all_data()[fn] def exists(self, path): return MockTarget(path).exists() def remove(self, path, recursive=True, skip_trash=True): """ Removes the given mockfile. skip_trash doesn't have any meaning. """ if recursive: to_delete = [] for s in self.get_all_data().keys(): if s.startswith(path): to_delete.append(s) for s in to_delete: self.get_all_data().pop(s) else: self.get_all_data().pop(path) def move(self, path, dest, raise_if_exists=False): """ Moves a single file from path to dest """ if raise_if_exists and dest in self.get_all_data(): raise RuntimeError("Destination exists: %s" % path) contents = self.get_all_data().pop(path) self.get_all_data()[dest] = contents def listdir(self, path): """ listdir does a prefix match of self.get_all_data(), but doesn't yet support globs. """ return [s for s in self.get_all_data().keys() if s.startswith(path)] def isdir(self, path): return any(self.listdir(path)) def mkdir(self, path, parents=True, raise_if_exists=False): """ mkdir is a noop. 
""" pass def clear(self): self.get_all_data().clear() class MockTarget(target.FileSystemTarget): fs = MockFileSystem() def __init__(self, fn, is_tmp=None, mirror_on_stderr=False, format=None): self._mirror_on_stderr = mirror_on_stderr self.path = fn self.format = format or get_default_format() def exists( self, ): return self.path in self.fs.get_all_data() def move(self, path, raise_if_exists=False): """ Call MockFileSystem's move command """ self.fs.move(self.path, path, raise_if_exists) def rename(self, *args, **kwargs): """ Call move to rename self """ self.move(*args, **kwargs) def open(self, mode="r"): fn = self.path mock_target = self class Buffer(BytesIO): # Just to be able to do writing + reading from the same buffer _write_line = True def set_wrapper(self, wrapper): self.wrapper = wrapper def write(self, data): if mock_target._mirror_on_stderr: if self._write_line: sys.stderr.write(fn + ": ") if isinstance(data, bytes): sys.stderr.write(data.decode("utf8")) else: sys.stderr.write(data) if (data[-1]) == "\n": self._write_line = True else: self._write_line = False super(Buffer, self).write(data) def close(self): if mode[0] == "w": try: mock_target.wrapper.flush() except AttributeError: pass mock_target.fs.get_all_data()[fn] = self.getvalue() super(Buffer, self).close() def __exit__(self, exc_type, exc_val, exc_tb): if not exc_type: self.close() def __enter__(self): return self def readable(self): return mode[0] == "r" def writeable(self): return mode[0] == "w" def seekable(self): return False if mode[0] == "w": wrapper = self.format.pipe_writer(Buffer()) wrapper.set_wrapper(wrapper) return wrapper else: return self.format.pipe_reader(Buffer(self.fs.get_all_data()[fn])) ================================================ FILE: luigi/mypy.py ================================================ """Plugin that provides support for luigi.Task This Code reuses the code from mypy.plugins.dataclasses 
https://github.com/python/mypy/blob/0753e2a82dad35034e000609b6e8daa37238bfaa/mypy/plugins/dataclasses.py
"""

from __future__ import annotations

from typing import Callable, Dict, Final, Iterator, List, Literal, Optional

from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.nodes import (
    ARG_NAMED_OPT,
    ARG_POS,
    Argument,
    AssignmentStmt,
    Block,
    CallExpr,
    ClassDef,
    Context,
    EllipsisExpr,
    Expression,
    FuncDef,
    IfStmt,
    JsonDict,
    MemberExpr,
    NameExpr,
    PlaceholderNode,
    RefExpr,
    Statement,
    SymbolTableNode,
    TempNode,
    TypeInfo,
    Var,
)
from mypy.plugin import (
    ClassDefContext,
    FunctionContext,
    Plugin,
    SemanticAnalyzerPluginInterface,
)
from mypy.plugins.common import (
    add_method_to_class,
    deserialize_and_fixup_type,
)
from mypy.server.trigger import make_wildcard_trigger
from mypy.state import state
from mypy.typeops import map_type_from_supertype
from mypy.types import (
    AnyType,
    CallableType,
    Instance,
    NoneType,
    Type,
    TypeOfAny,
    get_proper_type,
)
from mypy.typevars import fill_typevars

# Key under which a Task class's collected attributes are stored in TypeInfo.metadata.
METADATA_TAG: Final[str] = "task"


class TaskPlugin(Plugin):
    """mypy plugin that types luigi ``Parameter(...)`` calls and synthesizes ``__init__`` for Task subclasses."""

    def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
        """Return the Task transformer for classes whose base has ``luigi.task.Task`` in its MRO, else None."""
        sym = self.lookup_fully_qualified(fullname)
        if sym and isinstance(sym.node, TypeInfo):
            if any(base.fullname == "luigi.task.Task" for base in sym.node.mro):
                return self._task_class_maker_callback
        return None

    def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
        """Adjust the return type of calls to ``luigi.parameter.Parameter`` and its subclasses."""
        if self.check_parameter(fullname):
            return self._task_parameter_field_callback
        return None

    def check_parameter(self, fullname):
        """Tell whether ``fullname`` names a class with ``luigi.parameter.Parameter`` in its MRO.

        Returns a bool when the name resolves to a class; otherwise falls through to an
        implicit ``None`` (falsy).
        """
        sym = self.lookup_fully_qualified(fullname)
        if sym and isinstance(sym.node, TypeInfo):
            return any(base.fullname == "luigi.parameter.Parameter" for base in sym.node.mro)

    def _task_class_maker_callback(self, ctx: ClassDefContext) -> None:
        """Run :class:`TaskTransformer` on the class being analyzed (its result is discarded)."""
        transformer = TaskTransformer(ctx.cls, ctx.reason, ctx.api, self)
        transformer.transform()

    def _infer_choice_enum_element_type(self, ctx: FunctionContext, default_type: Instance) -> Type:
        """Infer the element type for Choice/Enum parameter variants.

        Starts from the parameter's own type argument (Any if it has none). If the call
        passes a ``choices=`` argument whose type carries type arguments, the first of
        those replaces it.
        """
        element_type: Type = default_type.args[0] if default_type.args else AnyType(TypeOfAny.unannotated)
        for i, names in enumerate(ctx.arg_names):
            for j, name in enumerate(names):
                if name == "choices":
                    choices_type = get_proper_type(ctx.arg_types[i][j])
                    if isinstance(choices_type, Instance) and choices_type.args:
                        element_type = choices_type.args[0]
        return element_type

    def _task_parameter_field_callback(self, ctx: FunctionContext) -> Type:
        """Compute the type of a ``Parameter(...)`` call expression.

        When mypy's inferred return type (``ctx.default_return_type``) is an Instance:

        * ChoiceListParameter / EnumListParameter -> ``tuple[element]``
        * ChoiceParameter / EnumParameter -> the element type
        * otherwise, the type of an explicit non-``...`` ``default=`` argument, or Any.

        For any other inferred return type, the type of an explicit non-``...``
        ``default=`` argument is used, falling back to Any.

        e.g.
        ```python
        foo: int = luigi.IntParameter()  # Any, compatible with the `int` annotation
        bar = luigi.Parameter(default="x")  # str
        ```
        """
        # Try to get the return type from __new__ method
        default_type = ctx.default_return_type
        if isinstance(default_type, Instance):
            # Handle Choice/Enum list parameters (ChoiceListParameter, EnumListParameter)
            if default_type.type.fullname in (
                "luigi.parameter.ChoiceListParameter",
                "luigi.parameter.EnumListParameter",
            ):
                element_type = self._infer_choice_enum_element_type(ctx, default_type)
                return ctx.api.named_generic_type("builtins.tuple", [element_type])

            # Handle Choice/Enum scalar parameters (ChoiceParameter, EnumParameter)
            if default_type.type.fullname in (
                "luigi.parameter.ChoiceParameter",
                "luigi.parameter.EnumParameter",
            ):
                return self._infer_choice_enum_element_type(ctx, default_type)

            # Check if a 'default' argument is explicitly provided
            try:
                default_idx = ctx.callee_arg_names.index("default")
                if ctx.args[default_idx]:
                    default_arg = ctx.args[default_idx][0]
                    if not isinstance(default_arg, EllipsisExpr):
                        return ctx.arg_types[default_idx][0]
            except ValueError:
                # The callee has no parameter named "default".
                pass

            # For Parameter subclasses without explicit default, return Any
            # so that both annotation styles work:
            #   foo: int = IntParameter()           (resolved type annotation)
            #   foo: IntParameter = IntParameter()  (parameter type annotation)
            return AnyType(TypeOfAny.special_form)

        try:
            default_idx = ctx.callee_arg_names.index("default")
        except ValueError:
            return AnyType(TypeOfAny.unannotated)

        default_args = ctx.args[default_idx]

        if default_args:
            default_type = ctx.arg_types[default_idx][0]
            default_arg = default_args[0]
            if not isinstance(default_arg, EllipsisExpr):
                return default_type

        return AnyType(TypeOfAny.unannotated)


class TaskAttribute:
    """A Parameter-backed attribute of a Task class, as collected by :class:`TaskTransformer`."""

    def __init__(
        self,
        name: str,
        has_default: bool,
        line: int,
        column: int,
        type: Type | None,
        info: TypeInfo,
        api: SemanticAnalyzerPluginInterface,
    ) -> None:
        self.name = name
        self.has_default = has_default
        self.line = line
        self.column = column
        self.type = type  # Type as __init__ argument
        self.info = info  # class in whose body the attribute was collected
        self._api = api

    def to_argument(self, current_info: TypeInfo, *, of: Literal["__init__",]) -> Argument:
        """Build the ``__init__`` argument for this attribute as seen from ``current_info``."""
        if of == "__init__":
            # All arguments to __init__ are keyword-only and optional
            # This is because luigi can set parameters by configuration
            arg_kind = ARG_NAMED_OPT
        return Argument(
            variable=self.to_var(current_info),
            type_annotation=self.expand_type(current_info),
            initializer=EllipsisExpr() if self.has_default else None,  # Only used by stubgen
            kind=arg_kind,
        )

    def expand_type(self, current_info: TypeInfo) -> Type | None:
        """Return ``self.type`` with the declaring class's Self type bound to ``current_info``."""
        if self.type is not None and self.info.self_type is not None:
            # In general, it is not safe to call `expand_type()` during semantic analysis,
            # however this plugin is called very late, so all types should be fully ready.
            # Also, it is tricky to avoid eager expansion of Self types here (e.g. because
            # we serialize attributes).
            with state.strict_optional_set(self._api.options.strict_optional):
                return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
        return self.type

    def to_var(self, current_info: TypeInfo) -> Var:
        return Var(self.name, self.expand_type(current_info))

    def serialize(self) -> JsonDict:
        """Serialize for ``TypeInfo.metadata``; requires ``self.type`` to be set."""
        assert self.type
        return {
            "name": self.name,
            "has_default": self.has_default,
            "line": self.line,
            "column": self.column,
            "type": self.type.serialize(),
        }

    @classmethod
    def deserialize(cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface) -> TaskAttribute:
        """Inverse of :meth:`serialize`; ``info`` is the class whose metadata held ``data``."""
        data = data.copy()
        typ = deserialize_and_fixup_type(data.pop("type"), api)
        return cls(type=typ, info=info, **data, api=api)

    def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
        """Expands type vars in the context of a subtype when an attribute is inherited
        from a generic super type."""
        if self.type is not None:
            with state.strict_optional_set(self._api.options.strict_optional):
                self.type = map_type_from_supertype(self.type, sub_type, self.info)


class TaskTransformer:
    """Synthesize a typed, keyword-only ``__init__`` for a luigi Task subclass from its Parameter attributes."""

    def __init__(
        self,
        cls: ClassDef,
        reason: Expression | Statement,
        api: SemanticAnalyzerPluginInterface,
        task_plugin: TaskPlugin,
    ) -> None:
        self._cls = cls
        self._reason = reason
        self._api = api
        self._task_plugin = task_plugin

    def transform(self) -> bool:
        """Apply all the necessary transformations to the underlying luigi Task class.

        Adds a plugin-generated ``__init__`` (unless the class defines its own, or no
        attributes were found) and stores the attributes in ``info.metadata[METADATA_TAG]``
        for subclasses. Returns False if some attribute's type is not known yet.

        NOTE(review): ``TaskPlugin._task_class_maker_callback`` ignores this return value.
        """
        info = self._cls.info
        attributes = self.collect_attributes()

        if attributes is None:
            # Some definitions are not ready. We need another pass.
            return False
        for attr in attributes:
            if attr.type is None:
                return False
        # If there are no attributes, it may be that the semantic analyzer has not
        # processed them yet. In order to work around this, we can simply skip generating
        # __init__ if there are no attributes, because if the user truly did not define any,
        # then the object default __init__ with an empty signature will be present anyway.
        if ("__init__" not in info.names or info.names["__init__"].plugin_generated) and attributes:
            args = [attr.to_argument(info, of="__init__") for attr in attributes]
            add_method_to_class(self._api, self._cls, "__init__", args=args, return_type=NoneType())

        info.metadata[METADATA_TAG] = {
            "attributes": [attr.serialize() for attr in attributes],
        }

        return True

    def _get_assignment_statements_from_if_statement(self, stmt: IfStmt) -> Iterator[AssignmentStmt]:
        """Yield assignments from the reachable branches of an ``if`` statement."""
        for body in stmt.body:
            if not body.is_unreachable:
                yield from self._get_assignment_statements_from_block(body)
        if stmt.else_body is not None and not stmt.else_body.is_unreachable:
            yield from self._get_assignment_statements_from_block(stmt.else_body)

    def _get_assignment_statements_from_block(self, block: Block) -> Iterator[AssignmentStmt]:
        """Yield assignments in ``block``, descending into ``if`` statements."""
        for stmt in block.body:
            if isinstance(stmt, AssignmentStmt):
                yield stmt
            elif isinstance(stmt, IfStmt):
                yield from self._get_assignment_statements_from_if_statement(stmt)

    def collect_attributes(self) -> Optional[List[TaskAttribute]]:
        """Collect all attributes declared in the task and its parents.

        Only assignments whose right-hand side is a Parameter call (see
        :meth:`is_parameter_call`) are collected, e.g.

            a: int = luigi.IntParameter()
            b: str = luigi.Parameter(default="x")

        Return None if some base class hasn't been processed yet and thus we'll need
        to ask for another pass.
        NOTE(review): the current implementation always returns a list.
        """
        cls = self._cls

        # First, collect attributes belonging to any class in the MRO, ignoring duplicates.
        #
        # We iterate through the MRO in reverse because attrs defined in the parent must appear
        # earlier in the attributes list than attrs defined in the child.
        #
        # However, we also want attributes defined in the subtype to override ones defined
        # in the parent. We can implement this via a dict without disrupting the attr order
        # because dicts preserve insertion order in Python 3.7+.
        found_attrs: Dict[str, TaskAttribute] = {}
        for info in reversed(cls.info.mro[1:-1]):
            if METADATA_TAG not in info.metadata:
                continue
            # Each class depends on the set of attributes in its task ancestors.
            self._api.add_plugin_dependency(make_wildcard_trigger(info.fullname))

            for data in info.metadata[METADATA_TAG]["attributes"]:
                name: str = data["name"]

                attr = TaskAttribute.deserialize(info, data, self._api)
                # TODO: We shouldn't be performing type operations during the main
                #       semantic analysis pass, since some TypeInfo attributes might
                #       still be in flux. This should be performed in a later phase.
                attr.expand_typevar_from_subtype(cls.info)
                found_attrs[name] = attr

                sym_node = cls.info.names.get(name)
                if sym_node and sym_node.node and not isinstance(sym_node.node, Var):
                    self._api.fail(
                        "Task attribute may only be overridden by another attribute",
                        sym_node.node,
                    )

        # Second, collect attributes belonging to the current class.
        # NOTE(review): current_attr_names is filled below but never read.
        current_attr_names: set[str] = set()
        for stmt in self._get_assignment_statements_from_block(cls.defs):
            if not self.is_parameter_call(stmt.rvalue):
                continue

            # a: int, b: str = 1, 'foo' is not supported syntax so we
            # don't have to worry about it.
            lhs = stmt.lvalues[0]
            if not isinstance(lhs, NameExpr):
                continue

            sym = cls.info.names.get(lhs.name)
            if sym is None:
                # There was probably a semantic analysis error.
                continue

            node = sym.node
            assert not isinstance(node, PlaceholderNode)

            assert isinstance(node, Var)

            has_parameter_call, parameter_args = self._collect_parameter_args(stmt.rvalue)
            has_default = False
            # Ensure that something like x: int = field() is rejected
            # after an attribute with a default.
            # NOTE(review): no such ordering check is performed here (wording inherited
            # from mypy's dataclasses plugin).
            if has_parameter_call:
                has_default = "default" in parameter_args

            # All other assignments are already type checked.
            elif not isinstance(stmt.rvalue, TempNode):
                has_default = True

            if not has_default:
                # Make all non-default task attributes implicit because they are de-facto
                # set on self in the generated __init__(), not in the class body. On the other
                # hand, we don't know how custom task transforms initialize attributes,
                # so we don't treat them as implicit. This is required to support descriptors
                # (https://github.com/python/mypy/issues/14868).
                sym.implicit = True

            current_attr_names.add(lhs.name)
            with state.strict_optional_set(self._api.options.strict_optional):
                init_type = self._infer_task_attr_init_type(sym, stmt)

            # When the type annotation is a Parameter type, update the
            # symbol's type to the resolved type so that mypy uses it
            # for the __init__ parameter type
            if init_type is not None and init_type != sym.type:
                assert isinstance(node, Var)
                node.type = init_type

            found_attrs[lhs.name] = TaskAttribute(
                name=lhs.name,
                has_default=has_default,
                line=stmt.line,
                column=stmt.column,
                type=init_type,
                info=cls.info,
                api=self._api,
            )

        return list(found_attrs.values())

    def _collect_parameter_args(self, expr: Expression) -> tuple[bool, Dict[str, Expression]]:
        """Returns a tuple where the first value represents whether or not
        the expression is a call to luigi.Parameter()
        and the second value is a dictionary of the keyword arguments that luigi.Parameter() was called with.

        Positional arguments are reported as errors and skipped.
        """
        if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr):
            args = {}
            for name, arg in zip(expr.arg_names, expr.args):
                if name is None:
                    # NOTE: this is a workaround to get default value from a parameter
                    self._api.fail(
                        "Positional arguments are not allowed for parameters when using the mypy plugin. "
                        "Update your code to use named arguments, like luigi.Parameter(default='foo') instead of luigi.Parameter('foo')",
                        expr,
                    )
                    continue
                args[name] = arg
            return True, args
        return False, {}

    def _infer_task_attr_init_type(self, sym: SymbolTableNode, context: Context) -> Type | None:
        """Infer __init__ argument type for an attribute.

        In particular, possibly use the signature of __set__.
        """
        default = sym.type
        t = get_proper_type(sym.type)

        # If the type annotation is a Parameter subclass, resolve to the inner type T
        # e.g. IntParameter -> int, StrParameter -> str
        if isinstance(t, Instance):
            is_param = self._task_plugin.check_parameter(t.type.fullname)
            if is_param:
                resolved = self._resolve_parameter_type(t)
                return resolved

        if sym.implicit:
            return default

        # Perform a simple-minded inference from the signature of __set__, if present.
        # We can't use mypy.checkmember here, since this plugin runs before type checking.
        # We only support some basic scenarios here, which is hopefully sufficient for
        # the vast majority of use cases.
        if not isinstance(t, Instance):
            return default
        setter = t.type.get("__set__")

        if not setter:
            return default

        if isinstance(setter.node, FuncDef):
            super_info = t.type.get_containing_type_info("__set__")
            assert super_info
            if setter.type:
                setter_type = get_proper_type(map_type_from_supertype(setter.type, t.type, super_info))
            else:
                return AnyType(TypeOfAny.unannotated)
            # Expected shape: __set__(self, instance, value); the value's type is used.
            if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [
                ARG_POS,
                ARG_POS,
                ARG_POS,
            ]:
                return expand_type_by_instance(setter_type.arg_types[2], t)
            else:
                self._api.fail(f'Unsupported signature for "__set__" in "{t.type.name}"', context)
        else:
            self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context)

        return default

    def is_parameter_call(self, expr: Expression) -> bool:
        """Checks if the expression is a call to luigi.Parameter()

        For ``X.Y(...)`` callees whose node is unresolved, the textual name ``"X.Y"`` is
        looked up instead (e.g. ``luigi.IntParameter``).
        """
        if not isinstance(expr, CallExpr):
            return False

        callee = expr.callee
        fullname = None
        if isinstance(callee, MemberExpr):
            type_info = callee.node
            if type_info is None and isinstance(callee.expr, NameExpr):
                fullname = f"{callee.expr.name}.{callee.name}"
        elif isinstance(callee, NameExpr):
            type_info = callee.node
        else:
            return False

        if isinstance(type_info, TypeInfo):
            fullname = type_info.fullname

        return fullname is not None and self._task_plugin.check_parameter(fullname)

    def _resolve_parameter_type(self, t: Instance) -> Type:
        """Resolve a Parameter type annotation to its inner type T.

        e.g. IntParameter -> int, Parameter[str] -> str

        Only the class itself and its direct bases are inspected; anything else
        resolves to Any.
        """
        # Direct Parameter[T] usage (e.g. Parameter[str])
        if t.type.fullname == "luigi.parameter.Parameter" and t.args:
            return t.args[0]

        # Parameter subclass (e.g. IntParameter extends Parameter[int])
        for base in t.type.bases:
            if isinstance(base, Instance) and base.type.fullname == "luigi.parameter.Parameter":
                if base.args:
                    return base.args[0]
                break

        return AnyType(TypeOfAny.unannotated)


def plugin(version: str) -> type[Plugin]:
    """Entry point called by mypy to load this plugin; ``version`` is not used."""
    return TaskPlugin

================================================
FILE: luigi/notifications.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Supports sending emails when tasks fail.

This needs some more documentation.
See :doc:`/configuration` for configuration options.
In particular using the config `receiver` should set up Luigi so that it will send emails when tasks fail.

.. code-block:: ini

    [email]
    receiver=foo@bar.baz
"""

import logging
import socket
import sys
import textwrap

import luigi.parameter
from luigi.task import Config, Task

logger = logging.getLogger("luigi-interface")


DEFAULT_CLIENT_EMAIL = "luigi-client@%s" % socket.gethostname()


class TestNotificationsTask(Task):
    """
    You may invoke this task to quickly check if you correctly have setup your
    notifications Configuration.  You can run:

    .. 
code-block:: console $ luigi TestNotificationsTask --local-scheduler --email-force-send And then check your email inbox to see if you got an error email or any other kind of notifications that you expected. """ raise_in_complete = luigi.parameter.BoolParameter(description="If true, fail in complete() instead of run()") def run(self): raise ValueError("Testing notifications triggering") def complete(self): if self.raise_in_complete: raise ValueError("Testing notifications triggering") return False class email(Config): force_send = luigi.parameter.BoolParameter(default=False, description="Send e-mail even from a tty") format = luigi.parameter.ChoiceParameter( default="plain", config_path=dict(section="core", name="email-type"), choices=("plain", "html", "none"), description="Format type for sent e-mails" ) method = luigi.parameter.ChoiceParameter( default="smtp", config_path=dict(section="email", name="type"), choices=("smtp", "sendgrid", "ses", "sns"), description="Method for sending e-mail" ) prefix = luigi.parameter.Parameter(default="", config_path=dict(section="core", name="email-prefix"), description="Prefix for subject lines of all e-mails") receiver = luigi.parameter.Parameter(default="", config_path=dict(section="core", name="error-email"), description="Address to send error e-mails to") traceback_max_length = luigi.parameter.IntParameter(default=5000, description="Max length for error traceback") sender = luigi.parameter.Parameter( default=DEFAULT_CLIENT_EMAIL, config_path=dict(section="core", name="email-sender"), description="Address to send e-mails from" ) class smtp(Config): host = luigi.parameter.Parameter(default="localhost", config_path=dict(section="core", name="smtp_host"), description="Hostname of smtp server") local_hostname = luigi.parameter.Parameter( default=None, config_path=dict(section="core", name="smtp_local_hostname"), description="If specified, local_hostname is used as the FQDN of the local host in the HELO/EHLO command", ) no_tls = 
luigi.parameter.BoolParameter( default=False, config_path=dict(section="core", name="smtp_without_tls"), description="Do not use TLS in SMTP connections" ) password = luigi.parameter.Parameter(default=None, config_path=dict(section="core", name="smtp_password"), description="Password for the SMTP server login") port = luigi.parameter.IntParameter(default=0, config_path=dict(section="core", name="smtp_port"), description="Port number for smtp server") ssl = luigi.parameter.BoolParameter(default=False, config_path=dict(section="core", name="smtp_ssl"), description="Use SSL for the SMTP connection.") timeout = luigi.parameter.FloatParameter( default=10.0, config_path=dict(section="core", name="smtp_timeout"), description="Number of seconds before timing out the smtp connection" ) username = luigi.parameter.Parameter( default=None, config_path=dict(section="core", name="smtp_login"), description="Username used to log in to the SMTP host" ) class sendgrid(Config): apikey = luigi.parameter.Parameter(config_path=dict(section="email", name="SENGRID_API_KEY"), description="API key for SendGrid login") def generate_email(sender, subject, message, recipients, image_png): from email.mime.image import MIMEImage from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText msg_root = MIMEMultipart("related") msg_text = MIMEText(message, email().format, "utf-8") msg_root.attach(msg_text) if image_png: with open(image_png, "rb") as fp: msg_image = MIMEImage(fp.read(), "png") msg_root.attach(msg_image) msg_root["Subject"] = subject msg_root["From"] = sender msg_root["To"] = ",".join(recipients) return msg_root def wrap_traceback(traceback): """ For internal use only (until further notice) """ if email().format == "html": try: from pygments import highlight from pygments.formatters import HtmlFormatter from pygments.lexers import PythonTracebackLexer with_pygments = True except ImportError: with_pygments = False if with_pygments: formatter = 
HtmlFormatter(noclasses=True) wrapped = highlight(traceback, PythonTracebackLexer(), formatter) else: wrapped = "
    %s
    " % traceback else: wrapped = traceback return wrapped def send_email_smtp(sender, subject, message, recipients, image_png): import smtplib smtp_config = smtp() kwargs = dict( host=smtp_config.host, port=smtp_config.port, local_hostname=smtp_config.local_hostname, ) if smtp_config.timeout: kwargs["timeout"] = smtp_config.timeout try: smtp_conn = smtplib.SMTP_SSL(**kwargs) if smtp_config.ssl else smtplib.SMTP(**kwargs) smtp_conn.ehlo_or_helo_if_needed() if smtp_conn.has_extn("starttls") and not smtp_config.no_tls: smtp_conn.starttls() if smtp_config.username and smtp_config.password: smtp_conn.login(smtp_config.username, smtp_config.password) msg_root = generate_email(sender, subject, message, recipients, image_png) smtp_conn.sendmail(sender, recipients, msg_root.as_string()) except socket.error as exception: logger.error("Not able to connect to smtp server: %s", exception) def send_email_ses(sender, subject, message, recipients, image_png): """ Sends notification through AWS SES. Does not handle access keys. Use either 1/ configuration file 2/ EC2 instance profile See also https://boto3.readthedocs.io/en/latest/guide/configuration.html. 
""" from boto3 import client as boto3_client client = boto3_client("ses") msg_root = generate_email(sender, subject, message, recipients, image_png) response = client.send_raw_email(Source=sender, Destinations=recipients, RawMessage={"Data": msg_root.as_string()}) logger.debug( ("Message sent to SES.\nMessageId: {},\nRequestId: {},\nHTTPSStatusCode: {}").format( response["MessageId"], response["ResponseMetadata"]["RequestId"], response["ResponseMetadata"]["HTTPStatusCode"] ) ) def send_email_sendgrid(sender, subject, message, recipients, image_png): import sendgrid as sendgrid_lib client = sendgrid_lib.SendGridAPIClient(sendgrid().apikey) to_send = sendgrid_lib.Mail(from_email=sender, to_emails=recipients, subject=subject) if email().format == "html": to_send.add_content(message, "text/html") else: to_send.add_content(message, "text/plain") if image_png: to_send.add_attachment(image_png) client.send(to_send) def _email_disabled_reason(): if email().format == "none": return "email format is 'none'" elif email().force_send: return None elif sys.stdout.isatty(): return "running from a tty" else: return None def send_email_sns(sender, subject, message, topic_ARN, image_png): """ Sends notification through AWS SNS. Takes Topic ARN from recipients. Does not handle access keys. Use either 1/ configuration file 2/ EC2 instance profile See also https://boto3.readthedocs.io/en/latest/guide/configuration.html. """ from boto3 import resource as boto3_resource sns = boto3_resource("sns") topic = sns.Topic(topic_ARN[0]) # Subject is max 100 chars if len(subject) > 100: subject = subject[0:48] + "..." 
+ subject[-49:] response = topic.publish(Subject=subject, Message=message) logger.debug( ("Message sent to SNS.\nMessageId: {},\nRequestId: {},\nHTTPSStatusCode: {}").format( response["MessageId"], response["ResponseMetadata"]["RequestId"], response["ResponseMetadata"]["HTTPStatusCode"] ) ) def send_email(subject, message, sender, recipients, image_png=None): """ Decides whether to send notification. Notification is cancelled if there are no recipients or if stdout is onto tty or if in debug mode. Dispatches on config value email.method. Default is 'smtp'. """ notifiers = { "ses": send_email_ses, "sendgrid": send_email_sendgrid, "smtp": send_email_smtp, "sns": send_email_sns, } subject = _prefix(subject) if not recipients or recipients == (None,): return if _email_disabled_reason(): logger.info("Not sending email to %r because %s", recipients, _email_disabled_reason()) return # Clean the recipients lists to allow multiple email addresses, comma # separated in luigi.cfg recipients_tmp = [] for r in recipients: recipients_tmp.extend([a.strip() for a in r.split(",") if a.strip()]) # Replace original recipients with the clean list recipients = recipients_tmp logger.info("Sending email to %r", recipients) # Get appropriate sender and call it to send the notification email_sender = notifiers[email().method] email_sender(sender, subject, message, recipients, image_png) def _email_recipients(additional_recipients=None): receiver = email().receiver recipients = [receiver] if receiver else [] if additional_recipients: if isinstance(additional_recipients, str): recipients.append(additional_recipients) else: recipients.extend(additional_recipients) return recipients def send_error_email(subject, message, additional_recipients=None): """ Sends an email to the configured error email, if it's configured. 
""" recipients = _email_recipients(additional_recipients) sender = email().sender send_email(subject=subject, message=message, sender=sender, recipients=recipients) def _prefix(subject): """ If the config has a special prefix for emails then this function adds this prefix. """ if email().prefix: return "{} {}".format(email().prefix, subject) else: return subject def format_task_error(headline, task, command, formatted_exception=None): """ Format a message body for an error email related to a Task :param headline: Summary line for the message :param task: `Task` instance where this error occurred :param formatted_exception: optional string showing traceback :return: message body """ if formatted_exception: if len(formatted_exception) > email().traceback_max_length: truncated_exception = formatted_exception[: email().traceback_max_length] formatted_exception = f"{truncated_exception}...Traceback exceeds max length and has been truncated." if formatted_exception: formatted_exception = wrap_traceback(formatted_exception) else: formatted_exception = "" if email().format == "html": msg_template = textwrap.dedent("""

    {headline}

    {param_rows}
    name{name}
  • Command line

            {command}
            

    Traceback

    {traceback} """) str_params = task.to_str_params() params = "\n".join("{}{}".format(*items) for items in str_params.items()) body = msg_template.format(headline=headline, name=task.task_family, param_rows=params, command=command, traceback=formatted_exception) else: msg_template = textwrap.dedent("""\ {headline} Name: {name} Parameters: {params} Command line: {command} {traceback} """) str_params = task.to_str_params() max_width = max([0] + [len(x) for x in str_params.keys()]) params = "\n".join(" {:{width}}: {}".format(*items, width=max_width) for items in str_params.items()) body = msg_template.format(headline=headline, name=task.task_family, params=params, command=command, traceback=formatted_exception) return body ================================================ FILE: luigi/parameter.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Parameters are one of the core concepts of Luigi. All Parameters sit on :class:`~luigi.task.Task` classes. See :ref:`Parameter` for more info on how to define parameters. 
""" import abc import datetime import json import operator import warnings from ast import literal_eval from enum import Enum, IntEnum from json import JSONEncoder from pathlib import Path from typing import ( Any, Callable, Dict, Generic, Iterable, Optional, Sequence, Tuple, Type, TypedDict, Union, overload, ) from typing_extensions import TypeVar, Unpack try: import jsonschema _JSONSCHEMA_ENABLED = True except ImportError: _JSONSCHEMA_ENABLED = False from configparser import NoOptionError, NoSectionError import luigi from luigi import configuration, date_interval, task_register from luigi.cmdline_parser import CmdlineParser from .freezing import FrozenOrderedDict, recursively_freeze, recursively_unfreeze class _NoValueType: """Sentinel class representing "no default value provided".""" _instance: "Optional[_NoValueType]" = None def __new__(cls) -> "_NoValueType": if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance def __repr__(self) -> str: return "" _no_value = _NoValueType() class ParameterVisibility(IntEnum): """ Possible values for the parameter visibility option. Public is the default. See :doc:`/parameters` for more info. """ PUBLIC = 0 HIDDEN = 1 PRIVATE = 2 @classmethod def has_value(cls, value): return any(value == item.value for item in cls) def serialize(self): return self.value class ParameterException(Exception): """ Base exception. """ pass class MissingParameterException(ParameterException): """ Exception signifying that there was a missing Parameter. """ pass class UnknownParameterException(ParameterException): """ Exception signifying that an unknown Parameter was supplied. """ pass class DuplicateParameterException(ParameterException): """ Exception signifying that a Parameter was specified multiple times. """ pass class OptionalParameterTypeWarning(UserWarning): """ Warning class for OptionalParameterMixin with wrong type. 
""" pass class UnconsumedParameterWarning(UserWarning): """Warning class for parameters that are not consumed by the task.""" T = TypeVar("T", default=str) _OptT = TypeVar("_OptT") class ConfigPath(TypedDict): section: str name: str class _ParameterKwargs(TypedDict, total=False): is_global: bool significant: bool description: Optional[str] config_path: Optional[ConfigPath] positional: bool always_in_help: bool batch_method: Optional[Callable[[Iterable[Any]], Any]] visibility: ParameterVisibility class Parameter(Generic[T]): """ Parameter whose value is a ``str``, and a base class for other parameter types. Parameters are objects set on the Task class level to make it possible to parameterize tasks. For instance: .. code:: python class MyTask(luigi.Task): foo = luigi.Parameter() class RequiringTask(luigi.Task): def requires(self): return MyTask(foo="hello") def run(self): print(self.requires().foo) # prints "hello" This makes it possible to instantiate multiple tasks, eg ``MyTask(foo='bar')`` and ``MyTask(foo='baz')``. The task will then have the ``foo`` attribute set appropriately. When a task is instantiated, it will first use any argument as the value of the parameter, eg. if you instantiate ``a = TaskA(x=44)`` then ``a.x == 44``. When the value is not provided, the value will be resolved in this order of falling priority: * Any value provided on the command line: - To the root task (eg. ``--param xyz``) - Then to the class, using the qualified task name syntax (eg. ``--TaskA-param xyz``). * With ``[TASK_NAME]>PARAM_NAME: `` syntax. See :ref:`ParamConfigIngestion` * Any default value set using the ``default`` flag. Parameter objects may be reused, but you must then set the ``positional=False`` flag. """ _counter = 0 # non-atomically increasing counter used for ordering parameters. 
def __init__( self, default: Union[T, _NoValueType] = _no_value, is_global: bool = False, significant: bool = True, description: Optional[str] = None, config_path: Optional[ConfigPath] = None, positional: bool = True, always_in_help: bool = False, batch_method: Optional[Callable[[Iterable[Any]], Any]] = None, visibility: ParameterVisibility = ParameterVisibility.PUBLIC, ): """ :param default: the default value for this parameter. This should match the type of the Parameter, i.e. ``datetime.date`` for ``DateParameter`` or ``int`` for ``IntParameter``. By default, no default is stored and the value must be specified at runtime. :param bool significant: specify ``False`` if the parameter should not be treated as part of the unique identifier for a Task. An insignificant Parameter might also be used to specify a password or other sensitive information that should not be made public via the scheduler. Default: ``True``. :param str description: A human-readable string describing the purpose of this Parameter. For command-line invocations, this will be used as the `help` string shown to users. Default: ``None``. :param dict config_path: a dictionary with entries ``section`` and ``name`` specifying a config file entry from which to read the default value for this parameter. DEPRECATED. Default: ``None``. :param bool positional: If true, you can set the argument as a positional argument. It's true by default but we recommend ``positional=False`` for abstract base classes and similar cases. :param bool always_in_help: For the --help option in the command line parsing. Set true to always show in --help. :param function(iterable[A])->A batch_method: Method to combine an iterable of parsed parameter values into a single value. Used when receiving batched parameter lists from the scheduler. See :ref:`batch_method` :param visibility: A Parameter whose value is a :py:class:`~luigi.parameter.ParameterVisibility`. 
Default value is ParameterVisibility.PUBLIC """ self._default = default self._batch_method = batch_method if is_global: warnings.warn("is_global support is removed. Assuming positional=False", DeprecationWarning, stacklevel=2) positional = False self.significant = significant # Whether different values for this parameter will differentiate otherwise equal tasks self.positional = positional self.visibility = visibility if ParameterVisibility.has_value(visibility) else ParameterVisibility.PUBLIC self.description = description self.always_in_help = always_in_help if config_path is not None and ("section" not in config_path or "name" not in config_path): raise ParameterException("config_path must be a hash containing entries for section and name") self._config_path = config_path self._counter = Parameter._counter # We need to keep track of this to get the order right (see Task class) Parameter._counter += 1 @overload def __get__(self, instance: None, owner: Any) -> "Parameter[T]": ... @overload def __get__(self, instance: Any, owner: Any) -> T: ... def __get__(self, instance: Any, owner: Any) -> Any: if instance is None: return self return instance.__dict__[self._attribute_name] def __set_name__(self, owner, name): self._attribute_name = name def __set__(self, instance: Any, value: T): if self._attribute_name is None: raise RuntimeError("Parameter name not set. ensure it's defined as a class attribute.") instance.__dict__[self._attribute_name] = value def _get_value_from_config(self, section, name): """Loads the default from the config. 
Returns _no_value if it doesn't exist""" conf = configuration.get_config() try: value = conf.get(section, name) except (NoSectionError, NoOptionError, KeyError): return _no_value return self.parse(value) def _get_value(self, task_name, param_name): for value, warn in self._value_iterator(task_name, param_name): if value != _no_value: if warn: warnings.warn(warn, DeprecationWarning) return value return _no_value def _value_iterator(self, task_name, param_name): """ Yield the parameter values, with optional deprecation warning as second tuple value. The parameter value will be whatever non-_no_value that is yielded first. """ cp_parser = CmdlineParser.get_instance() if cp_parser: dest = self._parser_global_dest(param_name, task_name) found = getattr(cp_parser.known_args, dest, None) yield (self._parse_or_no_value(found), None) yield (self._get_value_from_config(task_name, param_name), None) if self._config_path: yield ( self._get_value_from_config(self._config_path["section"], self._config_path["name"]), "The use of the configuration [{}] {} is deprecated. Please use [{}] {}".format( self._config_path["section"], self._config_path["name"], task_name, param_name ), ) yield (self._default, None) def has_task_value(self, task_name, param_name): return self._get_value(task_name, param_name) != _no_value def task_value(self, task_name, param_name): value = self._get_value(task_name, param_name) if value == _no_value: raise MissingParameterException("No default specified") else: return self.normalize(value) def _is_batchable(self): return self._batch_method is not None def parse(self, x): """ Parse an individual value from the input. The default implementation is the identity function, but subclasses should override this method for specialized parsing. :param str x: the value to parse. :return: the parsed value. """ return x # default impl def _parse_list(self, xs): """ Parse a list of values from the scheduler. Only possible if this is_batchable() is True. 
This will combine the list into a single parameter value using batch method. This should never need to be overridden. :param xs: list of values to parse and combine :return: the combined parsed values """ if not self._is_batchable(): raise NotImplementedError("No batch method found") elif not xs: raise ValueError("Empty parameter list passed to parse_list") else: return self._batch_method(map(self.parse, xs)) def serialize(self, x): """ Opposite of :py:meth:`parse`. Converts the value ``x`` to a string. :param x: the value to serialize. """ return str(x) def _warn_on_wrong_param_type(self, param_name, param_value): if self.__class__ != Parameter: return if not isinstance(param_value, str): warnings.warn('Parameter "{}" with value "{}" is not of type string.'.format(param_name, param_value)) def normalize(self, x): """ Given a parsed parameter value, normalizes it. The value can either be the result of parse(), the default value or arguments passed into the task's constructor by instantiation. This is very implementation defined, but can be used to validate/clamp valid values. For example, if you wanted to only accept even integers, and "correct" odd values to the nearest integer, you can implement normalize as ``x // 2 * 2``. """ return x # default impl def next_in_enumeration(self, value): """ If your Parameter type has an enumerable ordering of values. You can choose to override this method. This method is used by the :py:mod:`luigi.execution_summary` module for pretty printing purposes. Enabling it to pretty print tasks like ``MyTask(num=1), MyTask(num=2), MyTask(num=3)`` to ``MyTask(num=1..3)``. :param value: The value :return: The next value, like "value + 1". Or ``None`` if there's no enumerable ordering. 
""" return None def _parse_or_no_value(self, x): if not x: return _no_value else: return self.parse(x) @staticmethod def _parser_global_dest(param_name, task_name): return task_name + "_" + param_name @classmethod def _parser_kwargs(cls, param_name, task_name=None): return { "action": "store", "dest": cls._parser_global_dest(param_name, task_name) if task_name else param_name, } class OptionalParameterMixin(Generic[_OptT]): """ Mixin to make a parameter class optional and treat empty string as None. """ expected_type: type = type(None) def __init__( self, default: Union[_OptT, None, _NoValueType] = _no_value, **kwargs: Unpack[_ParameterKwargs], ): super().__init__(default=default, **kwargs) # type: ignore[arg-type, call-arg, misc] @overload def __get__(self, instance: None, owner: Any) -> "Parameter[Optional[_OptT]]": ... @overload def __get__(self, instance: Any, owner: Any) -> Optional[_OptT]: ... def __get__(self, instance: Any, owner: Any) -> Any: return super().__get__(instance, owner) # type: ignore[misc] def __set__(self, instance: Any, value: Optional[_OptT]): super().__set__(instance, value) # type: ignore[misc] def serialize(self, x): """ Parse the given value if the value is not None else return an empty string. """ if x is None: return "" else: return super().serialize(x) def parse(self, x): """ Parse the given value if it is a string (empty strings are parsed to None). """ if not isinstance(x, str): return x elif x: return super().parse(x) else: return None def normalize(self, x): """ Normalize the given value if it is not None. 
""" if x is None: return None return super().normalize(x) def _warn_on_wrong_param_type(self, param_name, param_value): if not isinstance(param_value, self.expected_type) and param_value is not None: try: param_type = "any type in " + str([i.__name__ for i in self.expected_type]).replace("'", '"') except TypeError: param_type = f'type "{self.expected_type.__name__}"' warnings.warn( (f'{self.__class__.__name__} "{param_name}" with value "{param_value}" is not of {param_type} or None.'), OptionalParameterTypeWarning, ) def next_in_enumeration(self, value): return None class OptionalParameter(OptionalParameterMixin[str], Parameter[Optional[str]]): """Class to parse optional parameters.""" expected_type = str class OptionalStrParameter(OptionalParameterMixin[str], Parameter[Optional[str]]): """Class to parse optional str parameters.""" expected_type = str _UNIX_EPOCH = datetime.datetime.utcfromtimestamp(0) class _DateParameterBase(Parameter[datetime.date]): """ Base class Parameter for date (not datetime). """ def __init__( self, default: Union[datetime.date, _NoValueType] = _no_value, interval: int = 1, start: Optional[datetime.date] = None, **kwargs: Unpack[_ParameterKwargs], ): super().__init__(default=default, **kwargs) self.interval = interval self.start = start if start is not None else _UNIX_EPOCH.date() @property @abc.abstractmethod def date_format(self): """ Override me with a :py:meth:`~datetime.date.strftime` string. """ pass def parse(self, s): """ Parses a date string formatted like ``YYYY-MM-DD``. """ return datetime.datetime.strptime(s, self.date_format).date() def serialize(self, dt): """ Converts the date to a string using the :py:attr:`~_DateParameterBase.date_format`. """ if dt is None: return str(dt) return dt.strftime(self.date_format) class DateParameter(_DateParameterBase): """ Parameter whose value is a :py:class:`~datetime.date`. A DateParameter is a Date string formatted ``YYYY-MM-DD``. For example, ``2013-07-10`` specifies July 10, 2013. 
DateParameters are 90% of the time used to be interpolated into file system paths or the like. Here is a gentle reminder of how to interpolate date parameters into strings: .. code:: python class MyTask(luigi.Task): date = luigi.DateParameter() def run(self): templated_path = "/my/path/to/my/dataset/{date:%Y/%m/%d}/" instantiated_path = templated_path.format(date=self.date) # print(instantiated_path) --> /my/path/to/my/dataset/2016/06/09/ # ... use instantiated_path ... To set this parameter to default to the current day. You can write code like this: .. code:: python import datetime class MyTask(luigi.Task): date = luigi.DateParameter(default=datetime.date.today()) """ date_format = "%Y-%m-%d" def next_in_enumeration(self, value): return value + datetime.timedelta(days=self.interval) def normalize(self, x): if x is None: return None if isinstance(x, datetime.datetime): x = x.date() delta = (x - self.start).days % self.interval return x - datetime.timedelta(days=delta) class MonthParameter(DateParameter): """ Parameter whose value is a :py:class:`~datetime.date`, specified to the month (day of :py:class:`~datetime.date` is "rounded" to first of the month). A MonthParameter is a Date string formatted ``YYYY-MM``. For example, ``2013-07`` specifies July of 2013. Task objects constructed from code accept :py:class:`~datetime.date` (ignoring the day value) or :py:class:`~luigi.date_interval.Month`. """ date_format = "%Y-%m" def _add_months(self, date, months): """ Add ``months`` months to ``date``. Unfortunately we can't use timedeltas to add months because timedelta counts in days and there's no foolproof way to add N months in days without counting the number of days per month. 
""" year = date.year + (date.month + months - 1) // 12 month = (date.month + months - 1) % 12 + 1 return datetime.date(year=year, month=month, day=1) def next_in_enumeration(self, value): return self._add_months(value, self.interval) def normalize(self, x): if x is None: return None if isinstance(x, date_interval.Month): x = x.date_a months_since_start = (x.year - self.start.year) * 12 + (x.month - self.start.month) months_since_start -= months_since_start % self.interval return self._add_months(self.start, months_since_start) class YearParameter(DateParameter): """ Parameter whose value is a :py:class:`~datetime.date`, specified to the year (day and month of :py:class:`~datetime.date` is "rounded" to first day of the year). A YearParameter is a Date string formatted ``YYYY``. Task objects constructed from code accept :py:class:`~datetime.date` (ignoring the month and day values) or :py:class:`~luigi.date_interval.Year`. """ date_format = "%Y" def next_in_enumeration(self, value): return value.replace(year=value.year + self.interval) def normalize(self, x): if x is None: return None if isinstance(x, date_interval.Year): x = x.date_a delta = (x.year - self.start.year) % self.interval return datetime.date(year=x.year - delta, month=1, day=1) class _DatetimeParameterBase(Parameter[datetime.datetime]): """ Base class Parameter for datetime """ def __init__( self, default: Union[datetime.datetime, _NoValueType] = _no_value, interval: int = 1, start: Optional[datetime.datetime] = None, **kwargs: Unpack[_ParameterKwargs], ): super().__init__(default=default, **kwargs) self.interval = interval self.start = start if start is not None else _UNIX_EPOCH @property @abc.abstractmethod def date_format(self): """ Override me with a :py:meth:`~datetime.date.strftime` string. """ pass @property @abc.abstractmethod def _timedelta(self): """ How to move one interval of this type forward (i.e. not counting self.interval). 
""" pass def parse(self, s): """ Parses a string to a :py:class:`~datetime.datetime`. """ return datetime.datetime.strptime(s, self.date_format) def serialize(self, dt): """ Converts the date to a string using the :py:attr:`~_DatetimeParameterBase.date_format`. """ if dt is None: return str(dt) return dt.strftime(self.date_format) @staticmethod def _convert_to_dt(dt): if not isinstance(dt, datetime.datetime): dt = datetime.datetime.combine(dt, datetime.time.min) return dt def normalize(self, dt): """ Clamp dt to every Nth :py:attr:`~_DatetimeParameterBase.interval` starting at :py:attr:`~_DatetimeParameterBase.start`. """ if dt is None: return None dt = self._convert_to_dt(dt) dt = dt.replace(microsecond=0) # remove microseconds, to avoid float rounding issues. delta = (dt - self.start).total_seconds() granularity = (self._timedelta * self.interval).total_seconds() return dt - datetime.timedelta(seconds=delta % granularity) def next_in_enumeration(self, value): return value + self._timedelta * self.interval class DateHourParameter(_DatetimeParameterBase): """ Parameter whose value is a :py:class:`~datetime.datetime` specified to the hour. A DateHourParameter is a `ISO 8601 `_ formatted date and time specified to the hour. For example, ``2013-07-10T19`` specifies July 10, 2013 at 19:00. """ date_format = "%Y-%m-%dT%H" # ISO 8601 is to use 'T' _timedelta = datetime.timedelta(hours=1) class DateMinuteParameter(_DatetimeParameterBase): """ Parameter whose value is a :py:class:`~datetime.datetime` specified to the minute. A DateMinuteParameter is a `ISO 8601 `_ formatted date and time specified to the minute. For example, ``2013-07-10T1907`` specifies July 10, 2013 at 19:07. The interval parameter can be used to clamp this parameter to every N minutes, instead of every minute. 
""" date_format = "%Y-%m-%dT%H%M" _timedelta = datetime.timedelta(minutes=1) deprecated_date_format = "%Y-%m-%dT%HH%M" def parse(self, x): try: value = datetime.datetime.strptime(x, self.deprecated_date_format) warnings.warn('Using "H" between hours and minutes is deprecated, omit it instead.', DeprecationWarning, stacklevel=2) return value except ValueError: return super().parse(x) class DateSecondParameter(_DatetimeParameterBase): """ Parameter whose value is a :py:class:`~datetime.datetime` specified to the second. A DateSecondParameter is a `ISO 8601 `_ formatted date and time specified to the second. For example, ``2013-07-10T190738`` specifies July 10, 2013 at 19:07:38. The interval parameter can be used to clamp this parameter to every N seconds, instead of every second. """ date_format = "%Y-%m-%dT%H%M%S" _timedelta = datetime.timedelta(seconds=1) class StrParameter(Parameter[str]): """ Parameter whose value is a ``str``. """ def parse(self, x): return str(x) class IntParameter(Parameter[int]): """ Parameter whose value is an ``int``. """ def parse(self, x): """ Parses an ``int`` from the string using ``int()``. """ return int(x) def next_in_enumeration(self, value): return value + 1 class OptionalIntParameter(OptionalParameterMixin[int], IntParameter): # type: ignore[misc] """Class to parse optional int parameters.""" expected_type = int class FloatParameter(Parameter[float]): """ Parameter whose value is a ``float``. """ def parse(self, x): """ Parses a ``float`` from the string using ``float()``. """ return float(x) class OptionalFloatParameter(OptionalParameterMixin[float], FloatParameter): # type: ignore[misc] """Class to parse optional float parameters.""" expected_type = float class BoolParameter(Parameter[bool]): """ A Parameter whose value is a ``bool``. This parameter has an implicit default value of ``False``. 
For the command line interface this means that the value is ``False`` unless you add ``"--the-bool-parameter"`` to your command without giving a parameter value. This is considered *implicit* parsing (the default). However, in some situations one might want to give the explicit bool value (``"--the-bool-parameter true|false"``), e.g. when you configure the default value to be ``True``. This is called *explicit* parsing. When omitting the parameter value, it is still considered ``True`` but to avoid ambiguities during argument parsing, make sure to always place bool parameters behind the task family on the command line when using explicit parsing. You can toggle between the two parsing modes on a per-parameter base via .. code-block:: python class MyTask(luigi.Task): implicit_bool = luigi.BoolParameter(parsing=luigi.BoolParameter.IMPLICIT_PARSING) explicit_bool = luigi.BoolParameter(parsing=luigi.BoolParameter.EXPLICIT_PARSING) or globally by .. code-block:: python luigi.BoolParameter.parsing = luigi.BoolParameter.EXPLICIT_PARSING for all bool parameters instantiated after this line. """ IMPLICIT_PARSING = "implicit" EXPLICIT_PARSING = "explicit" parsing = IMPLICIT_PARSING def __init__( self, default: Union[bool, _NoValueType] = _no_value, parsing: str = IMPLICIT_PARSING, **kwargs: Unpack[_ParameterKwargs], ): self.parsing = parsing super().__init__(default=default, **kwargs) if self._default == _no_value: self._default = False def parse(self, x): """ Parses a ``bool`` from the string, matching 'true' or 'false' ignoring case. 
""" s = str(x).lower() if s == "true": return True elif s == "false": return False else: raise ValueError("cannot interpret '{}' as boolean".format(x)) def normalize(self, x): try: return self.parse(x) except ValueError: return None def _parser_kwargs(self, *args, **kwargs): parser_kwargs = super()._parser_kwargs(*args, **kwargs) if self.parsing == self.IMPLICIT_PARSING: parser_kwargs["action"] = "store_true" elif self.parsing == self.EXPLICIT_PARSING: parser_kwargs["nargs"] = "?" parser_kwargs["const"] = True else: raise ValueError("unknown parsing value '{}'".format(self.parsing)) return parser_kwargs class OptionalBoolParameter(OptionalParameterMixin[bool], BoolParameter): # type: ignore[misc] """Class to parse optional bool parameters.""" expected_type = bool class DateIntervalParameter(Parameter[date_interval.DateInterval]): """ A Parameter whose value is a :py:class:`~luigi.date_interval.DateInterval`. Date Intervals are specified using the ISO 8601 date notation for dates (eg. "2015-11-04"), months (eg. "2015-05"), years (eg. "2015"), or weeks (eg. "2015-W35"). In addition, it also supports arbitrary date intervals provided as two dates separated with a dash (eg. "2015-11-04-2015-12-04"). """ def parse(self, x): """ Parses a :py:class:`~luigi.date_interval.DateInterval` from the input. see :py:mod:`luigi.date_interval` for details on the parsing of DateIntervals. """ # TODO: can we use xml.utils.iso8601 or something similar? from luigi import date_interval as d for cls in [d.Year, d.Month, d.Week, d.Date, d.Custom]: i = cls.parse(x) if i: return i raise ValueError("Invalid date interval - could not be parsed") class TimeDeltaParameter(Parameter[datetime.timedelta]): """ Class that maps to timedelta using strings in any of the following forms: * A bare number is interpreted as duration in seconds. * ``n {w[eek[s]]|d[ay[s]]|h[our[s]]|m[inute[s]|s[second[s]]}`` (e.g. 
"1 week 2 days" or "1 h") Note: multiple arguments must be supplied in longest to shortest unit order * ISO 8601 duration ``PnDTnHnMnS`` (each field optional, years and months not supported) * ISO 8601 duration ``PnW`` See https://en.wikipedia.org/wiki/ISO_8601#Durations """ def _apply_regex(self, regex, input): import re re_match = re.match(regex, input) if re_match and any(re_match.groups()): kwargs = {} has_val = False for k, v in re_match.groupdict(default="0").items(): val = int(v) if val > -1: has_val = True kwargs[k] = val if has_val: return datetime.timedelta(**kwargs) def _parseIso8601(self, input): def field(key): return r"(?P<%s>\d+)%s" % (key, key[0].upper()) def optional_field(key): return "(%s)?" % field(key) # A little loose: ISO 8601 does not allow weeks in combination with other fields, but this regex does (as does python timedelta) regex = "P(%s|%s(T%s)?)" % (field("weeks"), optional_field("days"), "".join([optional_field(key) for key in ["hours", "minutes", "seconds"]])) return self._apply_regex(regex, input) def _parseSimple(self, input): keys = ["weeks", "days", "hours", "minutes", "seconds"] # Give the digits a regex group name from the keys, then look for text with the first letter of the key, # optionally followed by the rest of the word, with final char (the "s") optional regex = "".join([r"((?P<%s>\d+) ?%s(%s)?(%s)? ?)?" % (k, k[0], k[1:-1], k[-1]) for k in keys]) return self._apply_regex(regex, input) def parse(self, x): """ Parses a time delta from the input. See :py:class:`TimeDeltaParameter` for details on supported formats. """ try: return datetime.timedelta(seconds=float(x)) except ValueError: pass result = self._parseIso8601(x) if not result: result = self._parseSimple(x) if result is not None: return result else: raise ParameterException("Invalid time delta - could not parse %s" % x) def serialize(self, x): """ Converts datetime.timedelta to a string :param x: the value to serialize. 
""" weeks = x.days // 7 days = x.days % 7 hours = x.seconds // 3600 minutes = (x.seconds % 3600) // 60 seconds = (x.seconds % 3600) % 60 result = "{} w {} d {} h {} m {} s".format(weeks, days, hours, minutes, seconds) return result def _warn_on_wrong_param_type(self, param_name, param_value): if self.__class__ != TimeDeltaParameter: return if not isinstance(param_value, datetime.timedelta): warnings.warn('Parameter "{}" with value "{}" is not of type timedelta.'.format(param_name, param_value)) TaskType = TypeVar("TaskType", bound="luigi.task.Task") class TaskParameter(Parameter[Type[TaskType]]): """ A parameter that takes another luigi task class. When used programatically, the parameter should be specified directly with the :py:class:`luigi.task.Task` (sub) class. Like ``MyMetaTask(my_task_param=my_tasks.MyTask)``. On the command line, you specify the :py:meth:`luigi.task.Task.get_task_family`. Like .. code-block:: console $ luigi --module my_tasks MyMetaTask --my_task_param my_namespace.MyTask Where ``my_namespace.MyTask`` is defined in the ``my_tasks`` python module. When the :py:class:`luigi.task.Task` class is instantiated to an object. The value will always be a task class (and not a string). """ def parse(self, x): """ Parse a task_famly using the :class:`~luigi.task_register.Register` """ return task_register.Register.get_task_cls(x) def serialize(self, x): """ Converts the :py:class:`luigi.task.Task` (sub) class to its family name. """ return x.get_task_family() EnumParameterType = TypeVar("EnumParameterType", bound=Enum) class EnumParameter(Parameter[EnumParameterType]): """ A parameter whose value is an :class:`~enum.Enum`. In the task definition, use .. code-block:: python class Model(enum.Enum): Honda = 1 Volvo = 2 class MyTask(luigi.Task): my_param = luigi.EnumParameter(enum=Model) At the command line, use, .. 
code-block:: console $ luigi --module my_tasks MyTask --my-param Honda """ def __init__( self, default: Union[EnumParameterType, _NoValueType] = _no_value, *, enum: Optional[Type[EnumParameterType]] = None, **kwargs: Unpack[_ParameterKwargs], ): if enum is None: raise ParameterException("An enum class must be specified.") self._enum = enum super().__init__(default=default, **kwargs) def parse(self, x): try: return self._enum[x] except KeyError: raise ValueError("Invalid enum value - could not be parsed") def serialize(self, x): return x.name class EnumListParameter(Parameter[Tuple[EnumParameterType, ...]]): """ A parameter whose value is a comma-separated list of :class:`~enum.Enum`. Values should come from the same enum. Values are taken to be a list, i.e. order is preserved, duplicates may occur, and empty list is possible. In the task definition, use .. code-block:: python class Model(enum.Enum): Honda = 1 Volvo = 2 class MyTask(luigi.Task): my_param = luigi.EnumListParameter(enum=Model) At the command line, use, .. code-block:: console $ luigi --module my_tasks MyTask --my-param Honda,Volvo """ _sep = "," def __init__( self, default: Union[Tuple[EnumParameterType, ...], _NoValueType] = _no_value, *, enum: Optional[Type[EnumParameterType]] = None, **kwargs: Unpack[_ParameterKwargs], ): if enum is None: raise ParameterException("An enum class must be specified.") self._enum = enum super().__init__(default=default, **kwargs) def parse(self, x): values = [] if x == "" else x.split(self._sep) for i, v in enumerate(values): try: values[i] = self._enum[v] except KeyError: raise ValueError('Invalid enum value "{}" index {} - could not be parsed'.format(v, i)) return tuple(values) def serialize(self, x): return self._sep.join([e.name for e in x]) class _DictParamEncoder(JSONEncoder): """ JSON encoder for :py:class:`~DictParameter`, which makes :py:class:`~FrozenOrderedDict` JSON serializable. 
""" def default(self, obj): if isinstance(obj, FrozenOrderedDict): return obj.get_wrapped() return json.JSONEncoder.default(self, obj) DictT = TypeVar("DictT", bound=dict, default=Dict[Any, Any]) class DictParameter(Parameter[DictT]): """ Parameter whose value is a ``dict``. In the task definition, use .. code-block:: python class MyTask(luigi.Task): tags = luigi.DictParameter() def run(self): logging.info("Find server with role: %s", self.tags['role']) server = aws.ec2.find_my_resource(self.tags) At the command line, use .. code-block:: console $ luigi --module my_tasks MyTask --tags Simple example with two tags: .. code-block:: console $ luigi --module my_tasks MyTask --tags '{"role": "web", "env": "staging"}' It can be used to define dynamic parameters, when you do not know the exact list of your parameters (e.g. list of tags, that are dynamically constructed outside Luigi), or you have a complex parameter containing logically related values (like a database connection config). It is possible to provide a JSON schema that should be validated by the given value: .. code-block:: python class MyTask(luigi.Task): tags = luigi.DictParameter( schema={ "type": "object", "patternProperties": { ".*": {"type": "string", "enum": ["web", "staging"]}, } } ) def run(self): logging.info("Find server with role: %s", self.tags['role']) server = aws.ec2.find_my_resource(self.tags) Using this schema, the following command will work: .. code-block:: console $ luigi --module my_tasks MyTask --tags '{"role": "web", "env": "staging"}' while this command will fail because the parameter is not valid: .. code-block:: console $ luigi --module my_tasks MyTask --tags '{"role": "UNKNOWN_VALUE", "env": "staging"}' Finally, the provided schema can be a custom validator: .. 
code-block:: python custom_validator = jsonschema.Draft4Validator( schema={ "type": "object", "patternProperties": { ".*": {"type": "string", "enum": ["web", "staging"]}, } } ) class MyTask(luigi.Task): tags = luigi.DictParameter(schema=custom_validator) def run(self): logging.info("Find server with role: %s", self.tags['role']) server = aws.ec2.find_my_resource(self.tags) """ def __init__( self, default: Union[DictT, _NoValueType] = _no_value, *, schema=None, **kwargs: Unpack[_ParameterKwargs], ): if schema is not None and not _JSONSCHEMA_ENABLED: warnings.warn("The 'jsonschema' package is not installed so the parameter can not be validated even though a schema is given.") self.schema = None else: self.schema = schema super().__init__(default=default, **kwargs) def normalize(self, x): """ Ensure that dictionary parameter is converted to a FrozenOrderedDict so it can be hashed. """ if self.schema is not None: unfrozen_value = recursively_unfreeze(x) try: self.schema.validate(unfrozen_value) x = unfrozen_value # Validators may update the instance inplace except AttributeError: jsonschema.validate(instance=unfrozen_value, schema=self.schema) return recursively_freeze(x) def parse(self, x): """ Parses an immutable and ordered ``dict`` from a JSON string using standard JSON library. We need to use an immutable dictionary, to create a hashable parameter and also preserve the internal structure of parsing. The traversal order of standard ``dict`` is undefined, which can result various string representations of this parameter, and therefore a different task id for the task containing this parameter. This is because task id contains the hash of parameters' JSON representation. :param s: String to be parse """ # TOML based config convert params to python types itself. 
if not isinstance(x, str): return x return json.loads(x, object_pairs_hook=FrozenOrderedDict) def serialize(self, x): return json.dumps(x, cls=_DictParamEncoder) class OptionalDictParameter(OptionalParameterMixin[FrozenOrderedDict], DictParameter): # type: ignore[misc] """Class to parse optional dict parameters.""" expected_type = FrozenOrderedDict ListT = TypeVar("ListT", bound=tuple, default=Tuple[Any, ...]) class ListParameter(Parameter[ListT]): """ Parameter whose value is a ``list``. In the task definition, use .. code-block:: python class MyTask(luigi.Task): grades = luigi.ListParameter() def run(self): sum = 0 for element in self.grades: sum += element avg = sum / len(self.grades) At the command line, use .. code-block:: console $ luigi --module my_tasks MyTask --grades Simple example with two grades: .. code-block:: console $ luigi --module my_tasks MyTask --grades '[100,70]' It is possible to provide a JSON schema that should be validated by the given value: .. code-block:: python class MyTask(luigi.Task): grades = luigi.ListParameter( schema={ "type": "array", "items": { "type": "number", "minimum": 0, "maximum": 10 }, "minItems": 1 } ) def run(self): sum = 0 for element in self.grades: sum += element avg = sum / len(self.grades) Using this schema, the following command will work: .. code-block:: console $ luigi --module my_tasks MyTask --numbers '[1, 8.7, 6]' while these commands will fail because the parameter is not valid: .. code-block:: console $ luigi --module my_tasks MyTask --numbers '[]' # must have at least 1 element $ luigi --module my_tasks MyTask --numbers '[-999, 999]' # elements must be in [0, 10] Finally, the provided schema can be a custom validator: .. 
code-block:: python custom_validator = jsonschema.Draft4Validator( schema={ "type": "array", "items": { "type": "number", "minimum": 0, "maximum": 10 }, "minItems": 1 } ) class MyTask(luigi.Task): grades = luigi.ListParameter(schema=custom_validator) def run(self): sum = 0 for element in self.grades: sum += element avg = sum / len(self.grades) """ def __init__( self, default: Union[ListT, _NoValueType] = _no_value, *, schema=None, **kwargs: Unpack[_ParameterKwargs], ): if schema is not None and not _JSONSCHEMA_ENABLED: warnings.warn("The 'jsonschema' package is not installed so the parameter can not be validated even though a schema is given.") self.schema = None else: self.schema = schema super().__init__(default=default, **kwargs) def normalize(self, x): """ Ensure that struct is recursively converted to a tuple so it can be hashed. :param str x: the value to parse. :return: the normalized (hashable/immutable) value. """ if self.schema is not None: unfrozen_value = recursively_unfreeze(x) try: self.schema.validate(unfrozen_value) x = unfrozen_value # Validators may update the instance inplace except AttributeError: jsonschema.validate(instance=unfrozen_value, schema=self.schema) return recursively_freeze(x) def parse(self, x): """ Parse an individual value from the input. :param str x: the value to parse. :return: the parsed value. """ i = json.loads(x, object_pairs_hook=FrozenOrderedDict) if i is None: return None return list(i) def serialize(self, x): """ Opposite of :py:meth:`parse`. Converts the value ``x`` to a string. :param x: the value to serialize. """ return json.dumps(x, cls=_DictParamEncoder) class OptionalListParameter(OptionalParameterMixin[ListT], ListParameter): # type: ignore[misc] """Class to parse optional list parameters.""" expected_type = tuple class TupleParameter(ListParameter[ListT]): """ Parameter whose value is a ``tuple`` or ``tuple`` of tuples. In the task definition, use .. 
code-block:: python class MyTask(luigi.Task): book_locations = luigi.TupleParameter() def run(self): for location in self.book_locations: print("Go to page %d, line %d" % (location[0], location[1])) At the command line, use .. code-block:: console $ luigi --module my_tasks MyTask --book_locations Simple example with two grades: .. code-block:: console $ luigi --module my_tasks MyTask --book_locations '((12,3),(4,15),(52,1))' """ def parse(self, x): """ Parse an individual value from the input. :param str x: the value to parse. :return: the parsed value. """ # Since the result of json.dumps(tuple) differs from a tuple string, we must handle either case. # A tuple string may come from a config file or from cli execution. # t = ((1, 2), (3, 4)) # t_str = '((1,2),(3,4))' # t_json_str = json.dumps(t) # t_json_str == '[[1, 2], [3, 4]]' # json.loads(t_json_str) == t # json.loads(t_str) == ValueError: No JSON object could be decoded # Therefore, if json.loads(x) returns a ValueError, try ast.literal_eval(x). # ast.literal_eval(t_str) == t try: # loop required to parse tuple of tuples return tuple(self._convert_iterable_to_tuple(x) for x in json.loads(x, object_pairs_hook=FrozenOrderedDict)) except (ValueError, TypeError): result = literal_eval(x) # t_str = '("abcd")' # Ensure that the result is not a string to avoid cases like ('a','b','c','d') if isinstance(result, str): raise ValueError("Parsed result cannot be a string") return tuple(result) # if this causes an error, let that error be raised. def _convert_iterable_to_tuple(self, x): if isinstance(x, str): return x return tuple(x) class OptionalTupleParameter(OptionalParameterMixin[ListT], TupleParameter): # type: ignore[misc] """Class to parse optional tuple parameters.""" expected_type = tuple NumericalType = TypeVar("NumericalType", int, float) class NumericalParameter(Parameter[NumericalType]): """ Parameter whose value is a number of the specified type, e.g. ``int`` or ``float`` and in the range specified. 
In the task definition, use .. code-block:: python class MyTask(luigi.Task): my_param_1 = luigi.NumericalParameter( var_type=int, min_value=-3, max_value=7) # -3 <= my_param_1 < 7 my_param_2 = luigi.NumericalParameter( var_type=int, min_value=-3, max_value=7, left_op=operator.lt, right_op=operator.le) # -3 < my_param_2 <= 7 At the command line, use .. code-block:: console $ luigi --module my_tasks MyTask --my-param-1 -3 --my-param-2 -2 """ def __init__( self, default: Union[NumericalType, _NoValueType] = _no_value, *, var_type: Optional[Type[NumericalType]] = None, min_value: Optional[NumericalType] = None, max_value: Optional[NumericalType] = None, left_op=operator.le, right_op=operator.lt, **kwargs: Unpack[_ParameterKwargs], ): """ :param function var_type: The type of the input variable, e.g. int or float. :param min_value: The minimum value permissible in the accepted values range. May be inclusive or exclusive based on left_op parameter. This should be the same type as var_type. :param max_value: The maximum value permissible in the accepted values range. May be inclusive or exclusive based on right_op parameter. This should be the same type as var_type. :param function left_op: The comparison operator for the left-most comparison in the expression ``min_value left_op value right_op value``. This operator should generally be either ``operator.lt`` or ``operator.le``. Default: ``operator.le``. :param function right_op: The comparison operator for the right-most comparison in the expression ``min_value left_op value right_op value``. This operator should generally be either ``operator.lt`` or ``operator.le``. Default: ``operator.lt``. 
""" if var_type is None: raise ParameterException("var_type must be specified") self._var_type: Type[NumericalType] = var_type if min_value is None: raise ParameterException("min_value must be specified") self._min_value: NumericalType = min_value if max_value is None: raise ParameterException("max_value must be specified") self._max_value: NumericalType = max_value self._left_op = left_op self._right_op = right_op self._permitted_range = "{var_type} in {left_endpoint}{min_value}, {max_value}{right_endpoint}".format( var_type=self._var_type.__name__, min_value=self._min_value, max_value=self._max_value, left_endpoint="[" if left_op == operator.le else "(", right_endpoint=")" if right_op == operator.lt else "]", ) super().__init__(default=default, **kwargs) # type: ignore[arg-type] if self.description: self.description += " " else: self.description = "" self.description += "permitted values: " + self._permitted_range def parse(self, x): value = self._var_type(x) if self._left_op(self._min_value, value) and self._right_op(value, self._max_value): return value else: raise ValueError("{s} is not in the set of {permitted_range}".format(s=x, permitted_range=self._permitted_range)) class OptionalNumericalParameter(OptionalParameterMixin[NumericalType], NumericalParameter[NumericalType]): # type: ignore[misc] """Class to parse optional numerical parameters.""" def __init__( self, default: Union[Optional[NumericalType], _NoValueType] = _no_value, **kwargs: Unpack[_ParameterKwargs], ): NumericalParameter.__init__(self, default=default, **kwargs) # type: ignore[arg-type, misc] self.expected_type = self._var_type ChoiceType = TypeVar("ChoiceType", default=str) class ChoiceParameter(Parameter[ChoiceType]): """ A parameter which takes two values: 1. an instance of :class:`~collections.Iterable` and 2. the class of the variables to convert to. In the task definition, use .. 
code-block:: python class MyTask(luigi.Task): my_param = luigi.ChoiceParameter(choices=[0.1, 0.2, 0.3], var_type=float) At the command line, use .. code-block:: console $ luigi --module my_tasks MyTask --my-param 0.1 Consider using :class:`~luigi.EnumParameter` for a typed, structured alternative. This class can perform the same role when all choices are the same type and transparency of parameter value on the command line is desired. """ def __init__( self, default: Union[ChoiceType, _NoValueType] = _no_value, *, choices: Optional[Sequence[ChoiceType]] = None, var_type: Type[ChoiceType] = str, # type: ignore[assignment] **kwargs: Unpack[_ParameterKwargs], ): """ :param function var_type: The type of the input variable, e.g. str, int, float, etc. Default: str :param choices: An iterable, all of whose elements are of `var_type` to restrict parameter choices to. """ if choices is None: raise ParameterException("A choices iterable must be specified") self._choices = set(choices) self._var_type = var_type assert all(type(choice) is self._var_type for choice in self._choices), "Invalid type in choices" super().__init__(default=default, **kwargs) if self.description: self.description += " " else: self.description = "" self.description += "Choices: {" + ", ".join(str(choice) for choice in self._choices) + "}" def parse(self, x): var = self._var_type(x) return self.normalize(var) def normalize(self, x): if x in self._choices: return x else: raise ValueError("{var} is not a valid choice from {choices}".format(var=x, choices=self._choices)) class ChoiceListParameter(ChoiceParameter[ChoiceType]): """ A parameter which takes two values: 1. an instance of :class:`~collections.Iterable` and 2. the class of the variables to convert to. Values are taken to be a list, i.e. order is preserved, duplicates may occur, and empty list is possible. In the task definition, use .. 
code-block:: python class MyTask(luigi.Task): my_param = luigi.ChoiceListParameter(choices=['foo', 'bar', 'baz'], var_type=str) At the command line, use .. code-block:: console $ luigi --module my_tasks MyTask --my-param foo,bar Consider using :class:`~luigi.EnumListParameter` for a typed, structured alternative. This class can perform the same role when all choices are the same type and transparency of parameter value on the command line is desired. """ _sep = "," @overload # type: ignore[override] def __get__(self, instance: None, owner: Any) -> "Parameter[Tuple[ChoiceType, ...]]": ... @overload def __get__(self, instance: Any, owner: Any) -> Tuple[ChoiceType, ...]: ... def __get__(self, instance: Any, owner: Any) -> Any: return super().__get__(instance, owner) def __init__( self, default: Union[Tuple[ChoiceType, ...], _NoValueType] = _no_value, var_type: Type[ChoiceType] = str, # type: ignore[assignment] choices: Optional[Sequence[ChoiceType]] = None, **kwargs: Unpack[_ParameterKwargs], ): super().__init__(default=default, var_type=var_type, choices=choices, **kwargs) # type: ignore[arg-type] def parse(self, x): values = [] if x == "" else x.split(self._sep) return self.normalize(map(self._var_type, values)) def normalize(self, x): values = [] for v in x: values.append(super().normalize(v)) return tuple(values) def serialize(self, x): return self._sep.join(x) class OptionalChoiceParameter(OptionalParameterMixin[ChoiceType], ChoiceParameter[ChoiceType]): # type: ignore[misc] """Class to parse optional choice parameters.""" def __init__( self, default: Union[Optional[ChoiceType], _NoValueType] = _no_value, var_type: Type[ChoiceType] = str, # type: ignore[assignment] choices: Optional[Sequence[ChoiceType]] = None, **kwargs: Unpack[_ParameterKwargs], ): ChoiceParameter.__init__(self, default=default, var_type=var_type, choices=choices, **kwargs) # type: ignore[arg-type, misc] self.expected_type = self._var_type class PathParameter(Parameter[Path]): """ Parameter 
whose value is a path. In the task definition, use .. code-block:: python class MyTask(luigi.Task): existing_file_path = luigi.PathParameter(exists=True) new_file_path = luigi.PathParameter() def run(self): # Get data from existing file with self.existing_file_path.open("r", encoding="utf-8") as f: data = f.read() # Output message in new file self.new_file_path.parent.mkdir(parents=True, exist_ok=True) with self.new_file_path.open("w", encoding="utf-8") as f: f.write("hello from a PathParameter => ") f.write(data) At the command line, use .. code-block:: console $ luigi --module my_tasks MyTask --existing-file-path --new-file-path """ def __init__( self, default: Union[Path, _NoValueType] = _no_value, *, absolute: bool = False, exists: bool = False, **kwargs: Unpack[_ParameterKwargs], ): """ :param bool absolute: If set to ``True``, the given path is converted to an absolute path. :param bool exists: If set to ``True``, a :class:`ValueError` is raised if the path does not exist. """ super().__init__(default=default, **kwargs) self.absolute = absolute self.exists = exists def normalize(self, x): """ Normalize the given value to a :class:`pathlib.Path` object. """ path = Path(x) if self.absolute: path = path.absolute() if self.exists and not path.exists(): raise ValueError(f"The path {path} does not exist.") return path class OptionalPathParameter(OptionalParameter, PathParameter): # type: ignore[misc] """Class to parse optional path parameters.""" expected_type = (str, Path) # type: ignore[assignment] ================================================ FILE: luigi/process.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Contains some helper functions to run luigid in daemon mode """ import datetime import logging import logging.handlers import os rootlogger = logging.getLogger() server_logger = logging.getLogger("luigi.server") def check_pid(pidfile): if pidfile and os.path.exists(pidfile): try: pid = int(open(pidfile).read().strip()) os.kill(pid, 0) return pid except BaseException: return 0 return 0 def write_pid(pidfile): server_logger.info("Writing pid file") piddir = os.path.dirname(pidfile) if piddir != "": try: os.makedirs(piddir) except OSError: pass with open(pidfile, "w") as fobj: fobj.write(str(os.getpid())) def get_log_format(): return "%(asctime)s %(name)s[%(process)s] %(levelname)s: %(message)s" def get_spool_handler(filename): handler = logging.handlers.TimedRotatingFileHandler( filename=filename, when="d", encoding="utf8", backupCount=7, # keep one week of historical logs ) formatter = logging.Formatter(get_log_format()) handler.setFormatter(formatter) return handler def _server_already_running(pidfile): existing_pid = check_pid(pidfile) if pidfile and existing_pid: return True return False def daemonize(cmd, pidfile=None, logdir=None, api_port=8082, address=None, unix_socket=None): import daemon logdir = logdir or "/var/log/luigi" if not os.path.exists(logdir): os.makedirs(logdir) log_path = os.path.join(logdir, "luigi-server.log") # redirect stdout/stderr today = datetime.date.today() stdout_path = os.path.join(logdir, "luigi-server-{0:%Y-%m-%d}.out".format(today)) stderr_path = os.path.join(logdir, "luigi-server-{0:%Y-%m-%d}.err".format(today)) stdout_proxy 
= open(stdout_path, "a+") stderr_proxy = open(stderr_path, "a+") try: ctx = daemon.DaemonContext( stdout=stdout_proxy, stderr=stderr_proxy, working_directory=".", initgroups=False, ) except TypeError: # Older versions of python-daemon cannot deal with initgroups arg. ctx = daemon.DaemonContext( stdout=stdout_proxy, stderr=stderr_proxy, working_directory=".", ) with ctx: loghandler = get_spool_handler(log_path) rootlogger.addHandler(loghandler) if pidfile: server_logger.info("Checking pid file") existing_pid = check_pid(pidfile) if pidfile and existing_pid: server_logger.info("Server already running (pid=%s)", existing_pid) return write_pid(pidfile) cmd(api_port=api_port, address=address, unix_socket=unix_socket) ================================================ FILE: luigi/py.typed ================================================ ================================================ FILE: luigi/retcodes.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Module containing the logic for exit codes for the luigi binary. It's useful when you in a programmatic way need to know if luigi actually finished the given task, and if not why. """ import logging import sys import luigi from luigi import IntParameter from luigi.setup_logging import InterfaceLogging class retcode(luigi.Config): """ See the :ref:`return codes configuration section `. 
""" # default value inconsistent with doc/configuration.rst for backwards compatibility reasons unhandled_exception = IntParameter( default=4, description="For internal luigi errors.", ) # default value inconsistent with doc/configuration.rst for backwards compatibility reasons missing_data = IntParameter( default=0, description="For when there are incomplete ExternalTask dependencies.", ) # default value inconsistent with doc/configuration.rst for backwards compatibility reasons task_failed = IntParameter( default=0, description="""For when a task's run() method fails.""", ) # default value inconsistent with doc/configuration.rst for backwards compatibility reasons already_running = IntParameter( default=0, description='For both local --lock and luigid "lock"', ) # default value inconsistent with doc/configuration.rst for backwards compatibility reasons scheduling_error = IntParameter( default=0, description="""For when a task's complete() or requires() fails, or task-limit reached""", ) # default value inconsistent with doc/configuration.rst for backwards compatibility reasons not_run = IntParameter(default=0, description="For when a task is not granted run permission by the scheduler.") def run_with_retcodes(argv): """ Run luigi with command line parsing, but raise ``SystemExit`` with the configured exit code. Note: Usually you use the luigi binary directly and don't call this function yourself. 
:param argv: Should (conceptually) be ``sys.argv[1:]`` """ logger = logging.getLogger("luigi-interface") with luigi.cmdline_parser.CmdlineParser.global_instance(argv): retcodes = retcode() worker = None try: worker = luigi.interface._run(argv).worker except luigi.interface.PidLockAlreadyTakenExit: sys.exit(retcodes.already_running) except Exception: # Some errors occur before logging is set up, we set it up now env_params = luigi.interface.core() InterfaceLogging.setup(env_params) logger.exception("Uncaught exception in luigi") sys.exit(retcodes.unhandled_exception) with luigi.cmdline_parser.CmdlineParser.global_instance(argv): task_sets = luigi.execution_summary._summary_dict(worker) root_task = luigi.execution_summary._root_task(worker) non_empty_categories = {k: v for k, v in task_sets.items() if v}.keys() def has(status): assert status in luigi.execution_summary._ORDERED_STATUSES return status in non_empty_categories codes_and_conds = ( (retcodes.missing_data, has("still_pending_ext")), (retcodes.task_failed, has("failed")), (retcodes.already_running, has("run_by_other_worker")), (retcodes.scheduling_error, has("scheduling_error")), (retcodes.not_run, has("not_run")), ) expected_ret_code = max(code * (1 if cond else 0) for code, cond in codes_and_conds) if expected_ret_code == 0 and root_task not in task_sets["completed"] and root_task not in task_sets["already_done"]: sys.exit(retcodes.not_run) else: sys.exit(expected_ret_code) ================================================ FILE: luigi/rpc.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Implementation of the REST interface between the workers and the server. rpc.py implements the client side of it, server.py implements the server side. See :doc:`/central_scheduler` for more info. """ import abc import base64 import json import logging import os import socket from urllib.error import URLError from urllib.parse import urlencode, urljoin, urlparse from urllib.request import Request, urlopen from tenacity import Retrying, stop_after_attempt, wait_fixed from luigi import configuration from luigi.scheduler import RPC_METHODS HAS_UNIX_SOCKET = True HAS_REQUESTS = True try: import requests_unixsocket as requests except ImportError: HAS_UNIX_SOCKET = False try: import requests except ImportError: HAS_REQUESTS = False logger = logging.getLogger("luigi-interface") # TODO: 'interface'? def _urljoin(base, url): """ Join relative URLs to base URLs like urllib.parse.urljoin but support arbitrary URIs (esp. 'http+unix://'). base part is fixed or mounted point, every url contains full base part. 
""" parsed = urlparse(base) scheme = parsed.scheme return urlparse(urljoin(parsed._replace(scheme="http").geturl(), parsed.path + (url if url[0] == "/" else "/" + url)))._replace(scheme=scheme).geturl() class RPCError(Exception): def __init__(self, message, sub_exception=None): super(RPCError, self).__init__(message) self.sub_exception = sub_exception class _FetcherInterface(metaclass=abc.ABCMeta): @abc.abstractmethod def fetch(self, full_url, body, timeout): pass @abc.abstractmethod def close(self): pass class URLLibFetcher(_FetcherInterface): raises = (URLError, socket.timeout) def _create_request(self, full_url, body=None): # when full_url contains basic auth info, extract it and set the Authorization header url = urlparse(full_url) if url.username: # base64 encoding of username:password auth = base64.b64encode("{}:{}".format(url.username, url.password or "").encode("utf-8")) auth = auth.decode("utf-8") # update full_url and create a request object with the auth header set full_url = url._replace(netloc=url.netloc.split("@", 1)[-1]).geturl() req = Request(full_url) req.add_header("Authorization", "Basic {}".format(auth)) else: req = Request(full_url) # add the request body if body: req.data = urlencode(body).encode("utf-8") return req def fetch(self, full_url, body, timeout): req = self._create_request(full_url, body=body) return urlopen(req, timeout=timeout).read().decode("utf-8") def close(self): pass class RequestsFetcher(_FetcherInterface): def __init__(self): from requests import exceptions as requests_exceptions self.raises = requests_exceptions.RequestException self.session = requests.Session() self.process_id = os.getpid() def check_pid(self): # if the process id change changed from when the session was created # a new session needs to be setup since requests isn't multiprocessing safe. 
if os.getpid() != self.process_id: self.session = requests.Session() self.process_id = os.getpid() def fetch(self, full_url, body, timeout): self.check_pid() resp = self.session.post(full_url, data=body, timeout=timeout) resp.raise_for_status() return resp.text def close(self): self.session.close() class RemoteScheduler: """ Scheduler proxy object. Talks to a RemoteSchedulerResponder. """ def __init__(self, url="http://localhost:8082/", connect_timeout=None): assert not url.startswith("http+unix://") or HAS_UNIX_SOCKET, "You need to install requests-unixsocket for Unix socket support." self._url = url.rstrip("/") config = configuration.get_config() if connect_timeout is None: connect_timeout = config.getfloat("core", "rpc-connect-timeout", 10.0) self._connect_timeout = connect_timeout self._rpc_retry_attempts = config.getint("core", "rpc-retry-attempts", 3) self._rpc_retry_wait = config.getint("core", "rpc-retry-wait", 30) self._rpc_log_retries = config.getboolean("core", "rpc-log-retries", True) if HAS_REQUESTS: self._fetcher = RequestsFetcher() else: self._fetcher = URLLibFetcher() def close(self): self._fetcher.close() def _get_retryer(self): def retry_logging(retry_state): if self._rpc_log_retries: logger.warning("Failed connecting to remote scheduler %r", self._url, exc_info=True) logger.info("Retrying attempt %r of %r (max)" % (retry_state.attempt_number + 1, self._rpc_retry_attempts)) logger.info("Wait for %d seconds" % self._rpc_retry_wait) return Retrying(wait=wait_fixed(self._rpc_retry_wait), stop=stop_after_attempt(self._rpc_retry_attempts), reraise=True, after=retry_logging) def _fetch(self, url_suffix, body): full_url = _urljoin(self._url, url_suffix) scheduler_retry = self._get_retryer() try: response = scheduler_retry(self._fetcher.fetch, full_url, body, self._connect_timeout) except self._fetcher.raises as e: raise RPCError("Errors (%d attempts) when connecting to remote scheduler %r" % (self._rpc_retry_attempts, self._url), e) return response def 
_request(self, url, data, attempts=3, allow_null=True): body = {"data": json.dumps(data)} for _ in range(attempts): page = self._fetch(url, body) response = json.loads(page)["response"] if allow_null or response is not None: return response raise RPCError("Received null response from remote scheduler %r" % self._url) for method_name, method in RPC_METHODS.items(): setattr(RemoteScheduler, method_name, method) ================================================ FILE: luigi/safe_extractor.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ This module provides a class `SafeExtractor` that offers a secure way to extract tar files while mitigating path traversal vulnerabilities, which can occur when files inside the archive are crafted to escape the intended extraction directory. The `SafeExtractor` ensures that the extracted file paths are validated before extraction to prevent malicious archives from extracting files outside the intended directory. Classes: SafeExtractor: A class to securely extract tar files with protection against path traversal attacks. Usage Example: extractor = SafeExtractor("/desired/directory") extractor.safe_extract("archive.tar") """ import os import tarfile class SafeExtractor: """ A class to safely extract tar files, ensuring that no path traversal vulnerabilities are exploited. Attributes: path (str): The directory to extract files into. 
Methods: _is_within_directory(directory, target): Checks if a target path is within a given directory. safe_extract(tar_path, members=None, \\*, numeric_owner=False): Safely extracts the contents of a tar file to the specified directory. """ def __init__(self, path="."): """ Initializes the SafeExtractor with the specified directory path. Args: path (str): The directory to extract files into. Defaults to the current directory. """ self.path = path @staticmethod def _is_within_directory(directory, target): """ Checks if a target path is within a given directory. Args: directory (str): The directory to check against. target (str): The target path to check. Returns: bool: True if the target path is within the directory, False otherwise. """ abs_directory = os.path.abspath(directory) abs_target = os.path.abspath(target) prefix = os.path.commonprefix([abs_directory, abs_target]) return prefix == abs_directory def safe_extract(self, tar_path, members=None, *, numeric_owner=False): """ Safely extracts the contents of a tar file to the specified directory. Args: tar_path (str): The path to the tar file to extract. members (list, optional): A list of members to extract. Defaults to None. numeric_owner (bool, optional): If True, only the numeric owner will be used. Defaults to False. Raises: RuntimeError: If a path traversal attempt is detected. """ with tarfile.open(tar_path, "r") as tar: for member in tar.getmembers(): member_path = os.path.join(self.path, member.name) if not self._is_within_directory(self.path, member_path): raise RuntimeError("Attempted Path Traversal in Tar File") tar.extractall(self.path, members, numeric_owner=numeric_owner) ================================================ FILE: luigi/scheduler.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ The system for scheduling tasks and executing them in order. Deals with dependencies, priorities, resources, etc. The :py:class:`~luigi.worker.Worker` pulls tasks from the scheduler (usually over the REST interface) and executes them. See :doc:`/central_scheduler` for more info. """ import collections import functools import hashlib import inspect import itertools import json import logging import os import pickle import re import time import uuid from collections.abc import MutableSet from luigi import configuration, notifications, parameter from luigi import task_history as history from luigi.batch_notifier import BatchNotifier from luigi.metrics import MetricsCollectors from luigi.parameter import ParameterVisibility from luigi.task import Config from luigi.task_status import BATCH_RUNNING, DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN logger = logging.getLogger(__name__) UPSTREAM_RUNNING = "UPSTREAM_RUNNING" UPSTREAM_MISSING_INPUT = "UPSTREAM_MISSING_INPUT" UPSTREAM_FAILED = "UPSTREAM_FAILED" UPSTREAM_DISABLED = "UPSTREAM_DISABLED" UPSTREAM_SEVERITY_ORDER = ( "", UPSTREAM_RUNNING, UPSTREAM_MISSING_INPUT, UPSTREAM_FAILED, UPSTREAM_DISABLED, ) UPSTREAM_SEVERITY_KEY = UPSTREAM_SEVERITY_ORDER.index STATUS_TO_UPSTREAM_MAP = { FAILED: UPSTREAM_FAILED, RUNNING: UPSTREAM_RUNNING, BATCH_RUNNING: UPSTREAM_RUNNING, PENDING: UPSTREAM_MISSING_INPUT, DISABLED: UPSTREAM_DISABLED, } WORKER_STATE_DISABLED = "disabled" WORKER_STATE_ACTIVE = "active" TASK_FAMILY_RE = re.compile(r"([^(_]+)[(_]") RPC_METHODS = {} _retry_policy_fields = [ "retry_count", 
"disable_hard_timeout", "disable_window", ] RetryPolicy = collections.namedtuple("RetryPolicy", _retry_policy_fields) # type: ignore def _get_empty_retry_policy(): return RetryPolicy(*[None] * len(_retry_policy_fields)) def rpc_method(**request_args): def _rpc_method(fn): # If request args are passed, return this function again for use as # the decorator function with the request args attached. args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, ann = inspect.getfullargspec(fn) assert not varargs first_arg, *all_args = args assert first_arg == "self" defaults = dict(zip(reversed(all_args), reversed(defaults or ()))) required_args = frozenset(arg for arg in all_args if arg not in defaults) fn_name = fn.__name__ @functools.wraps(fn) def rpc_func(self, *args, **kwargs): actual_args = defaults.copy() actual_args.update(dict(zip(all_args, args))) actual_args.update(kwargs) if not all(arg in actual_args for arg in required_args): raise TypeError("{} takes {} arguments ({} given)".format(fn_name, len(all_args), len(actual_args))) return self._request("/api/{}".format(fn_name), actual_args, **request_args) RPC_METHODS[fn_name] = rpc_func return fn return _rpc_method class scheduler(Config): retry_delay = parameter.FloatParameter(default=900.0) remove_delay = parameter.FloatParameter(default=600.0) worker_disconnect_delay = parameter.FloatParameter(default=60.0) state_path = parameter.Parameter(default="/var/lib/luigi-server/state.pickle") batch_emails = parameter.BoolParameter(default=False, description="Send e-mails in batches rather than immediately") # Jobs are disabled if we see more than retry_count failures in disable_window seconds. # These disables last for disable_persist seconds. 
disable_window = parameter.IntParameter(default=3600) retry_count = parameter.IntParameter(default=999999999) disable_hard_timeout = parameter.IntParameter(default=999999999) disable_persist = parameter.IntParameter(default=86400) max_shown_tasks = parameter.IntParameter(default=100000) max_graph_nodes = parameter.IntParameter(default=100000) record_task_history = parameter.BoolParameter(default=False) prune_on_get_work = parameter.BoolParameter(default=False) pause_enabled = parameter.BoolParameter(default=True) send_messages = parameter.BoolParameter(default=True) metrics_collector = parameter.EnumParameter(enum=MetricsCollectors, default=MetricsCollectors.default) metrics_custom_import = parameter.OptionalStrParameter(default=None) stable_done_cooldown_secs = parameter.IntParameter(default=10, description="Sets cooldown period to avoid running the same task twice") """ Sets a cooldown period in seconds after a task was completed, during this period the same task will not accepted by the scheduler. 
""" def _get_retry_policy(self): return RetryPolicy(self.retry_count, self.disable_hard_timeout, self.disable_window) def _get_default(x, default): if x is not None: return x else: return default class OrderedSet(MutableSet): """ Standard Python OrderedSet recipe found at http://code.activestate.com/recipes/576694/ Modified to include a peek function to get the last element """ def __init__(self, iterable=None): self.end = end = [] end += [None, end, end] # sentinel node for doubly linked list self.map = {} # key --> [key, prev, next] if iterable is not None: self |= iterable def __len__(self): return len(self.map) def __contains__(self, key): return key in self.map def add(self, key): if key not in self.map: end = self.end curr = end[1] curr[2] = end[1] = self.map[key] = [key, curr, end] def discard(self, key): if key in self.map: key, prev, next = self.map.pop(key) prev[2] = next next[1] = prev def __iter__(self): end = self.end curr = end[2] while curr is not end: yield curr[0] curr = curr[2] def __reversed__(self): end = self.end curr = end[1] while curr is not end: yield curr[0] curr = curr[1] def peek(self, last=True): if not self: raise KeyError("set is empty") key = self.end[1][0] if last else self.end[2][0] return key def pop(self, last=True): key = self.peek(last) self.discard(key) return key def __repr__(self): if not self: return "%s()" % (self.__class__.__name__,) return "%s(%r)" % (self.__class__.__name__, list(self)) def __eq__(self, other): if isinstance(other, OrderedSet): return len(self) == len(other) and list(self) == list(other) return set(self) == set(other) class Task: def __init__( self, task_id, status, deps, resources=None, priority=0, family="", module=None, params=None, param_visibilities=None, accepts_messages=False, tracking_url=None, status_message=None, progress_percentage=None, retry_policy="notoptional", ): self.id = task_id self.stakeholders = set() # workers ids that are somehow related to this task (i.e. 
don't prune while any of these workers are still active) self.workers = OrderedSet() # workers ids that can perform task - task is 'BROKEN' if none of these workers are active if deps is None: self.deps = set() else: self.deps = set(deps) self.status = status # PENDING, RUNNING, FAILED or DONE self.time = time.time() # Timestamp when task was first added self.updated = self.time self.retry = None self.remove = None self.worker_running = None # the worker id that is currently running the task or None self.time_running = None # Timestamp when picked up by worker self.expl = None self.priority = priority self.resources = _get_default(resources, {}) self.family = family self.module = module self.param_visibilities = _get_default(param_visibilities, {}) self.params = {} self.public_params = {} self.hidden_params = {} self.set_params(params) self.accepts_messages = accepts_messages self.retry_policy = retry_policy self.failures = collections.deque() self.first_failure_time = None self.tracking_url = tracking_url self.status_message = status_message self.progress_percentage = progress_percentage self.scheduler_message_responses = {} self.scheduler_disable_time = None self.runnable = False self.batchable = False self.batch_id = None def __repr__(self): return "Task(%r)" % vars(self) def set_params(self, params): self.params = _get_default(params, {}) self.public_params = { key: value for key, value in self.params.items() if self.param_visibilities.get(key, ParameterVisibility.PUBLIC) == ParameterVisibility.PUBLIC } self.hidden_params = { key: value for key, value in self.params.items() if self.param_visibilities.get(key, ParameterVisibility.PUBLIC) == ParameterVisibility.HIDDEN } # TODO(2017-08-10) replace this function with direct calls to batchable # this only exists for backward compatibility def is_batchable(self): try: return self.batchable except AttributeError: return False def add_failure(self): """ Add a failure event with the current timestamp. 
""" failure_time = time.time() if not self.first_failure_time: self.first_failure_time = failure_time self.failures.append(failure_time) def num_failures(self): """ Return the number of failures in the window. """ min_time = time.time() - self.retry_policy.disable_window while self.failures and self.failures[0] < min_time: self.failures.popleft() return len(self.failures) def has_excessive_failures(self): if self.first_failure_time is not None: if time.time() >= self.first_failure_time + self.retry_policy.disable_hard_timeout: return True logger.debug("%s task num failures is %s and limit is %s", self.id, self.num_failures(), self.retry_policy.retry_count) if self.num_failures() >= self.retry_policy.retry_count: logger.debug("%s task num failures limit(%s) is exceeded", self.id, self.retry_policy.retry_count) return True return False def clear_failures(self): """ Clear the failures history """ self.failures.clear() self.first_failure_time = None @property def pretty_id(self): param_str = ", ".join("{}={}".format(key, value) for key, value in sorted(self.public_params.items())) return "{}({})".format(self.family, param_str) class Worker: """ Structure for tracking worker activity and keeping their references. """ def __init__(self, worker_id, last_active=None): self.id = worker_id self.reference = None # reference to the worker in the real world. 
(Currently a dict containing just the host) self.last_active = last_active or time.time() # seconds since epoch self.last_get_work = None self.started = time.time() # seconds since epoch self.tasks = set() # task objects self.info = {} self.disabled = False self.rpc_messages = [] def add_info(self, info): self.info.update(info) def update(self, worker_reference, get_work=False): if worker_reference: self.reference = worker_reference self.last_active = time.time() if get_work: self.last_get_work = time.time() def prune(self, config): # Delete workers that haven't said anything for a while (probably killed) if self.last_active + config.worker_disconnect_delay < time.time(): return True def get_tasks(self, state, *statuses): num_self_tasks = len(self.tasks) num_state_tasks = sum(len(state._status_tasks[status]) for status in statuses) if num_self_tasks < num_state_tasks: return filter(lambda task: task.status in statuses, self.tasks) else: return filter(lambda task: self.id in task.workers, state.get_active_tasks_by_status(*statuses)) def is_trivial_worker(self, state): """ If it's not an assistant having only tasks that are without requirements. We have to pass the state parameter for optimization reasons. """ if self.assistant: return False return all(not task.resources for task in self.get_tasks(state, PENDING)) @property def assistant(self): return self.info.get("assistant", False) @property def enabled(self): return not self.disabled @property def state(self): if self.enabled: return WORKER_STATE_ACTIVE else: return WORKER_STATE_DISABLED def add_rpc_message(self, name, **kwargs): # the message has the format {'name': , 'kwargs': } self.rpc_messages.append({"name": name, "kwargs": kwargs}) def fetch_rpc_messages(self): messages = self.rpc_messages[:] del self.rpc_messages[:] return messages def __str__(self): return self.id class SimpleTaskState: """ Keep track of the current state and handle persistence. 
The point of this class is to enable other ways to keep state, eg. by using a database These will be implemented by creating an abstract base class that this and other classes inherit from. """ def __init__(self, state_path): self._state_path = state_path self._tasks = {} # map from id to a Task object self._status_tasks = collections.defaultdict(dict) self._active_workers = {} # map from id to a Worker object self._task_batchers = {} self._metrics_collector = None def get_state(self): return self._tasks, self._active_workers, self._task_batchers def set_state(self, state): self._tasks, self._active_workers = state[:2] if len(state) >= 3: self._task_batchers = state[2] def dump(self): try: with open(self._state_path, "wb") as fobj: pickle.dump(self.get_state(), fobj) except IOError: logger.warning("Failed saving scheduler state", exc_info=1) else: logger.info("Saved state in %s", self._state_path) # prone to lead to crashes when old state is unpickled with updated code. TODO some kind of version control? def load(self): if os.path.exists(self._state_path): logger.info("Attempting to load state from %s", self._state_path) try: with open(self._state_path, "rb") as fobj: state = pickle.load(fobj) except BaseException: logger.exception("Error when loading state. Starting from empty state.") return self.set_state(state) self._status_tasks = collections.defaultdict(dict) for task in self._tasks.values(): self._status_tasks[task.status][task.id] = task else: logger.info("No prior state file exists at %s. 
Starting with empty state", self._state_path) def get_active_tasks(self): return self._tasks.values() def get_active_tasks_by_status(self, *statuses): return itertools.chain.from_iterable(self._status_tasks[status].values() for status in statuses) def get_active_task_count_for_status(self, status): if status: return len(self._status_tasks[status]) else: return len(self._tasks) def get_batch_running_tasks(self, batch_id): assert batch_id is not None return [task for task in self.get_active_tasks_by_status(BATCH_RUNNING) if task.batch_id == batch_id] def set_batcher(self, worker_id, family, batcher_args, max_batch_size): self._task_batchers.setdefault(worker_id, {}) self._task_batchers[worker_id][family] = (batcher_args, max_batch_size) def get_batcher(self, worker_id, family): return self._task_batchers.get(worker_id, {}).get(family, (None, 1)) def num_pending_tasks(self): """ Return how many tasks are PENDING + RUNNING. O(1). """ return len(self._status_tasks[PENDING]) + len(self._status_tasks[RUNNING]) def get_task(self, task_id, default=None, setdefault=None): if setdefault: task = self._tasks.setdefault(task_id, setdefault) self._status_tasks[task.status][task.id] = task return task else: return self._tasks.get(task_id, default) def has_task(self, task_id): return task_id in self._tasks def re_enable(self, task, config=None): task.scheduler_disable_time = None task.clear_failures() if config: self.set_status(task, FAILED, config) task.clear_failures() def set_batch_running(self, task, batch_id, worker_id): self.set_status(task, BATCH_RUNNING) task.batch_id = batch_id task.worker_running = worker_id task.resources_running = task.resources task.time_running = time.time() def set_status(self, task, new_status, config=None): if new_status == FAILED: assert config is not None if new_status == DISABLED and task.status in (RUNNING, BATCH_RUNNING): return remove_on_failure = task.batch_id is not None and not task.batchable if task.status == DISABLED: if new_status == 
DONE: self.re_enable(task) # don't allow workers to override a scheduler disable elif task.scheduler_disable_time is not None and new_status != DISABLED: return if task.status == RUNNING and task.batch_id is not None and new_status != RUNNING: for batch_task in self.get_batch_running_tasks(task.batch_id): self.set_status(batch_task, new_status, config) batch_task.batch_id = None task.batch_id = None if new_status == FAILED and task.status != DISABLED: task.add_failure() if task.has_excessive_failures(): task.scheduler_disable_time = time.time() new_status = DISABLED if not config.batch_emails: notifications.send_error_email( "Luigi Scheduler: DISABLED {task} due to excessive failures".format(task=task.id), "{task} failed {failures} times in the last {window} seconds, so it is being disabled for {persist} seconds".format( failures=task.retry_policy.retry_count, task=task.id, window=task.retry_policy.disable_window, persist=config.disable_persist, ), ) elif new_status == DISABLED: task.scheduler_disable_time = None if new_status != task.status: self._status_tasks[task.status].pop(task.id) self._status_tasks[new_status][task.id] = task task.status = new_status task.updated = time.time() self.update_metrics(task, config) if new_status == FAILED: task.retry = time.time() + config.retry_delay if remove_on_failure: task.remove = time.time() def fail_dead_worker_task(self, task, config, assistants): # If a running worker disconnects, tag all its jobs as FAILED and subject it to the same retry logic if task.status in (BATCH_RUNNING, RUNNING) and task.worker_running and task.worker_running not in task.stakeholders | assistants: logger.info( "Task %r is marked as running by disconnected worker %r -> marking as FAILED with retry delay of %rs", task.id, task.worker_running, config.retry_delay, ) task.worker_running = None self.set_status(task, FAILED, config) task.retry = time.time() + config.retry_delay def update_status(self, task, config): # Mark tasks with no remaining 
active stakeholders for deletion if (not task.stakeholders) and (task.remove is None) and (task.status != RUNNING): # We don't check for the RUNNING case, because that is already handled # by the fail_dead_worker_task function. logger.debug("Task %r has no stakeholders anymore -> might remove task in %s seconds", task.id, config.remove_delay) task.remove = time.time() + config.remove_delay # Re-enable task after the disable time expires if task.status == DISABLED and task.scheduler_disable_time is not None: if time.time() - task.scheduler_disable_time > config.disable_persist: self.re_enable(task, config) # Reset FAILED tasks to PENDING if max timeout is reached, and retry delay is >= 0 if task.status == FAILED and config.retry_delay >= 0 and task.retry < time.time(): self.set_status(task, PENDING, config) def may_prune(self, task): return task.remove and time.time() >= task.remove def inactivate_tasks(self, delete_tasks): # The terminology is a bit confusing: we used to "delete" tasks when they became inactive, # but with a pluggable state storage, you might very well want to keep some history of # older tasks as well. 
That's why we call it "inactivate" (as in the verb) for task in delete_tasks: task_obj = self._tasks.pop(task) self._status_tasks[task_obj.status].pop(task) def get_active_workers(self, last_active_lt=None, last_get_work_gt=None): for worker in self._active_workers.values(): if last_active_lt is not None and worker.last_active >= last_active_lt: continue last_get_work = worker.last_get_work if last_get_work_gt is not None and (last_get_work is None or last_get_work <= last_get_work_gt): continue yield worker def get_assistants(self, last_active_lt=None): return filter(lambda w: w.assistant, self.get_active_workers(last_active_lt)) def get_worker_ids(self): return self._active_workers.keys() # only used for unit tests def get_worker(self, worker_id): return self._active_workers.setdefault(worker_id, Worker(worker_id)) def inactivate_workers(self, delete_workers): # Mark workers as inactive for worker in delete_workers: self._active_workers.pop(worker) self._remove_workers_from_tasks(delete_workers) def _remove_workers_from_tasks(self, workers, remove_stakeholders=True): for task in self.get_active_tasks(): if remove_stakeholders: task.stakeholders.difference_update(workers) task.workers -= workers def disable_workers(self, worker_ids): self._remove_workers_from_tasks(worker_ids, remove_stakeholders=False) for worker_id in worker_ids: worker = self.get_worker(worker_id) worker.disabled = True worker.tasks.clear() def update_metrics(self, task, config): if task.status == DISABLED: self._metrics_collector.handle_task_disabled(task, config) elif task.status == DONE: self._metrics_collector.handle_task_done(task) elif task.status == FAILED: self._metrics_collector.handle_task_failed(task) class Scheduler: """ Async scheduler that can handle multiple workers, etc. Can be run locally or on a server (using RemoteScheduler + server.Server). 
""" def __init__(self, config=None, resources=None, task_history_impl=None, **kwargs): """ Keyword Arguments: :param config: an object of class "scheduler" or None (in which the global instance will be used) :param resources: a dict of str->int constraints :param task_history_impl: ignore config and use this object as the task history """ self._config = config or scheduler(**kwargs) self._state = SimpleTaskState(self._config.state_path) if task_history_impl: self._task_history = task_history_impl elif self._config.record_task_history: from luigi import db_task_history # Needs sqlalchemy, thus imported here self._task_history = db_task_history.DbTaskHistory() else: self._task_history = history.NopHistory() self._resources = resources or configuration.get_config().getintdict("resources") # TODO: Can we make this a Parameter? self._make_task = functools.partial(Task, retry_policy=self._config._get_retry_policy()) self._worker_requests = {} self._paused = False if self._config.batch_emails: self._email_batcher = BatchNotifier() self._state._metrics_collector = MetricsCollectors.get(self._config.metrics_collector, self._config.metrics_custom_import) def load(self): self._state.load() def dump(self): self._state.dump() if self._config.batch_emails: self._email_batcher.send_email() @rpc_method() def prune(self): logger.debug("Starting pruning of task graph") self._prune_workers() self._prune_tasks() self._prune_emails() logger.debug("Done pruning task graph") def _prune_workers(self): remove_workers = [] for worker in self._state.get_active_workers(): if worker.prune(self._config): logger.debug("Worker %s timed out (no contact for >=%ss)", worker, self._config.worker_disconnect_delay) remove_workers.append(worker.id) self._state.inactivate_workers(remove_workers) def _prune_tasks(self): assistant_ids = {w.id for w in self._state.get_assistants()} remove_tasks = [] for task in self._state.get_active_tasks(): self._state.fail_dead_worker_task(task, self._config, 
assistant_ids) self._state.update_status(task, self._config) if self._state.may_prune(task): logger.info("Removing task %r", task.id) remove_tasks.append(task.id) self._state.inactivate_tasks(remove_tasks) def _prune_emails(self): if self._config.batch_emails: self._email_batcher.update() def _update_worker(self, worker_id, worker_reference=None, get_work=False): # Keep track of whenever the worker was last active. # For convenience also return the worker object. worker = self._state.get_worker(worker_id) worker.update(worker_reference, get_work=get_work) return worker def _update_priority(self, task, prio, worker): """ Update priority of the given task. Priority can only be increased. If the task doesn't exist, a placeholder task is created to preserve priority when the task is later scheduled. """ task.priority = prio = max(prio, task.priority) for dep in task.deps or []: t = self._state.get_task(dep) if t is not None and prio > t.priority: self._update_priority(t, prio, worker) @rpc_method() def add_task_batcher(self, worker, task_family, batched_args, max_batch_size=float("inf")): self._state.set_batcher(worker, task_family, batched_args, max_batch_size) @rpc_method() def forgive_failures(self, task_id=None): status = PENDING task = self._state.get_task(task_id) if task is None: return {"task_id": task_id, "status": None} # we forgive only failures if task.status == FAILED: # forgive but do not forget self._update_task_history(task, status) self._state.set_status(task, status, self._config) return {"task_id": task_id, "status": task.status} @rpc_method() def mark_as_done(self, task_id=None): status = DONE task = self._state.get_task(task_id) if task is None: return {"task_id": task_id, "status": None} # we can force mark DONE for running or failed tasks if task.status in {RUNNING, FAILED, DISABLED}: self._update_task_history(task, status) self._state.set_status(task, status, self._config) return {"task_id": task_id, "status": task.status} @rpc_method() def 
add_task( self, task_id=None, status=PENDING, runnable=True, deps=None, new_deps=None, expl=None, resources=None, priority=0, family="", module=None, params=None, param_visibilities=None, accepts_messages=False, assistant=False, tracking_url=None, worker=None, batchable=None, batch_id=None, retry_policy_dict=None, owners=None, **kwargs, ): """ * add task identified by task_id if it doesn't exist * if deps is not None, update dependency list * update status of task * add additional workers/stakeholders * update priority when needed """ assert worker is not None worker_id = worker worker = self._update_worker(worker_id) resources = {} if resources is None else resources.copy() if retry_policy_dict is None: retry_policy_dict = {} retry_policy = self._generate_retry_policy(retry_policy_dict) if worker.enabled: _default_task = self._make_task( task_id=task_id, status=PENDING, deps=deps, resources=resources, priority=priority, family=family, module=module, params=params, param_visibilities=param_visibilities, ) else: _default_task = None task = self._state.get_task(task_id, setdefault=_default_task) if task is None or (task.status != RUNNING and not worker.enabled): return # Ignore claims that the task is PENDING if it very recently was marked as DONE. 
if status == PENDING and task.status == DONE and (time.time() - task.updated) < self._config.stable_done_cooldown_secs: return # for setting priority, we'll sometimes create tasks with unset family and params if not task.family: task.family = family if not getattr(task, "module", None): task.module = module if not getattr(task, "param_visibilities", None): task.param_visibilities = _get_default(param_visibilities, {}) if not task.params: task.set_params(params) if batch_id is not None: task.batch_id = batch_id if status == RUNNING and not task.worker_running: task.worker_running = worker_id if batch_id: # copy resources_running of the first batch task batch_tasks = self._state.get_batch_running_tasks(batch_id) task.resources_running = batch_tasks[0].resources_running.copy() task.time_running = time.time() if accepts_messages is not None: task.accepts_messages = accepts_messages if tracking_url is not None or task.status != RUNNING: task.tracking_url = tracking_url if task.batch_id is not None: for batch_task in self._state.get_batch_running_tasks(task.batch_id): batch_task.tracking_url = tracking_url if batchable is not None: task.batchable = batchable if task.remove is not None: task.remove = None # unmark task for removal so it isn't removed after being added if expl is not None: task.expl = expl if task.batch_id is not None: for batch_task in self._state.get_batch_running_tasks(task.batch_id): batch_task.expl = expl task_is_not_running = task.status not in (RUNNING, BATCH_RUNNING) task_started_a_run = status in (DONE, FAILED, RUNNING) running_on_this_worker = task.worker_running == worker_id if task_is_not_running or (task_started_a_run and running_on_this_worker) or new_deps: # don't allow re-scheduling of task while it is running, it must either fail or succeed on the worker actually running it if status != task.status or status == PENDING: # Update the DB only if there was a acctual change, to prevent noise. 
# We also check for status == PENDING b/c that's the default value # (so checking for status != task.status woule lie) self._update_task_history(task, status) self._state.set_status(task, PENDING if status == SUSPENDED else status, self._config) if status == FAILED and self._config.batch_emails: batched_params, _ = self._state.get_batcher(worker_id, family) if batched_params: unbatched_params = {param: value for param, value in task.params.items() if param not in batched_params} else: unbatched_params = task.params try: expl_raw = json.loads(expl) except ValueError: expl_raw = expl self._email_batcher.add_failure(task.pretty_id, task.family, unbatched_params, expl_raw, owners) if task.status == DISABLED: self._email_batcher.add_disable(task.pretty_id, task.family, unbatched_params, owners) if deps is not None: task.deps = set(deps) if new_deps is not None: task.deps.update(new_deps) if resources is not None: task.resources = resources if worker.enabled and not assistant: task.stakeholders.add(worker_id) # Task dependencies might not exist yet. Let's create dummy tasks for them for now. 
# Otherwise the task dependencies might end up being pruned if scheduling takes a long time for dep in task.deps or []: t = self._state.get_task(dep, setdefault=self._make_task(task_id=dep, status=UNKNOWN, deps=None, priority=priority)) t.stakeholders.add(worker_id) self._update_priority(task, priority, worker_id) # Because some tasks (non-dynamic dependencies) are `_make_task`ed # before we know their retry_policy, we always set it here task.retry_policy = retry_policy if runnable and status != FAILED and worker.enabled: task.workers.add(worker_id) self._state.get_worker(worker_id).tasks.add(task) task.runnable = runnable @rpc_method() def announce_scheduling_failure(self, task_name, family, params, expl, owners, **kwargs): if not self._config.batch_emails: return worker_id = kwargs["worker"] batched_params, _ = self._state.get_batcher(worker_id, family) if batched_params: unbatched_params = {param: value for param, value in params.items() if param not in batched_params} else: unbatched_params = params self._email_batcher.add_scheduling_fail(task_name, family, unbatched_params, expl, owners) @rpc_method() def add_worker(self, worker, info, **kwargs): self._state.get_worker(worker).add_info(info) @rpc_method() def disable_worker(self, worker): self._state.disable_workers({worker}) @rpc_method() def set_worker_processes(self, worker, n): self._state.get_worker(worker).add_rpc_message("set_worker_processes", n=n) @rpc_method() def send_scheduler_message(self, worker, task, content): if not self._config.send_messages: return {"message_id": None} message_id = str(uuid.uuid4()) self._state.get_worker(worker).add_rpc_message("dispatch_scheduler_message", task_id=task, message_id=message_id, content=content) return {"message_id": message_id} @rpc_method() def add_scheduler_message_response(self, task_id, message_id, response): if self._state.has_task(task_id): task = self._state.get_task(task_id) task.scheduler_message_responses[message_id] = response @rpc_method() def 
get_scheduler_message_response(self, task_id, message_id): response = None if self._state.has_task(task_id): task = self._state.get_task(task_id) response = task.scheduler_message_responses.pop(message_id, None) return {"response": response} @rpc_method() def has_task_history(self): return self._config.record_task_history @rpc_method() def is_pause_enabled(self): return {"enabled": self._config.pause_enabled} @rpc_method() def is_paused(self): return {"paused": self._paused} @rpc_method() def pause(self): if self._config.pause_enabled: self._paused = True @rpc_method() def unpause(self): if self._config.pause_enabled: self._paused = False @rpc_method() def update_resources(self, **resources): if self._resources is None: self._resources = {} self._resources.update(resources) @rpc_method() def update_resource(self, resource, amount): if not isinstance(amount, int) or amount < 0: return False self._resources[resource] = amount return True def _generate_retry_policy(self, task_retry_policy_dict): retry_policy_dict = self._config._get_retry_policy()._asdict() retry_policy_dict.update({k: v for k, v in task_retry_policy_dict.items() if v is not None}) return RetryPolicy(**retry_policy_dict) def _has_resources(self, needed_resources, used_resources): if needed_resources is None: return True available_resources = self._resources or {} for resource, amount in needed_resources.items(): if amount + used_resources[resource] > available_resources.get(resource, 1): return False return True def _used_resources(self): used_resources = collections.defaultdict(int) if self._resources is not None: for task in self._state.get_active_tasks_by_status(RUNNING): resources_running = getattr(task, "resources_running", task.resources) if resources_running: for resource, amount in resources_running.items(): used_resources[resource] += amount return used_resources def _rank(self, task): """ Return worker's rank function for task scheduling. 
:return: """ return task.priority, -task.time def _schedulable(self, task): if task.status != PENDING: return False for dep in task.deps: dep_task = self._state.get_task(dep, default=None) if dep_task is None or dep_task.status != DONE: return False return True def _reset_orphaned_batch_running_tasks(self, worker_id): running_batch_ids = {task.batch_id for task in self._state.get_active_tasks_by_status(RUNNING) if task.worker_running == worker_id} orphaned_tasks = [ task for task in self._state.get_active_tasks_by_status(BATCH_RUNNING) if task.worker_running == worker_id and task.batch_id not in running_batch_ids ] for task in orphaned_tasks: self._state.set_status(task, PENDING) @rpc_method() def count_pending(self, worker): worker_id, worker = worker, self._state.get_worker(worker) num_pending, num_unique_pending, num_pending_last_scheduled = 0, 0, 0 running_tasks = [] upstream_status_table = {} for task in worker.get_tasks(self._state, RUNNING): if self._upstream_status(task.id, upstream_status_table) == UPSTREAM_DISABLED: continue # Return a list of currently running tasks to the client, # makes it easier to troubleshoot other_worker = self._state.get_worker(task.worker_running) if other_worker is not None: more_info = {"task_id": task.id, "worker": str(other_worker)} more_info.update(other_worker.info) running_tasks.append(more_info) for task in worker.get_tasks(self._state, PENDING, FAILED): if self._upstream_status(task.id, upstream_status_table) == UPSTREAM_DISABLED: continue num_pending += 1 num_unique_pending += int(len(task.workers) == 1) num_pending_last_scheduled += int(task.workers.peek(last=True) == worker_id) return { "n_pending_tasks": num_pending, "n_unique_pending": num_unique_pending, "n_pending_last_scheduled": num_pending_last_scheduled, "worker_state": worker.state, "running_tasks": running_tasks, } @rpc_method(allow_null=False) def get_work(self, host=None, assistant=False, current_tasks=None, worker=None, **kwargs): # TODO: remove any 
expired nodes # Algo: iterate over all nodes, find the highest priority node no dependencies and available # resources. # Resource checking looks both at currently available resources and at which resources would # be available if all running tasks died and we rescheduled all workers greedily. We do both # checks in order to prevent a worker with many low-priority tasks from starving other # workers with higher priority tasks that share the same resources. # TODO: remove tasks that can't be done, figure out if the worker has absolutely # nothing it can wait for if self._config.prune_on_get_work: self.prune() assert worker is not None worker_id = worker worker = self._update_worker(worker_id, worker_reference={"host": host}, get_work=True) if not worker.enabled: reply = { "n_pending_tasks": 0, "running_tasks": [], "task_id": None, "n_unique_pending": 0, "worker_state": worker.state, } return reply if assistant: self.add_worker(worker_id, [("assistant", assistant)]) batched_params, unbatched_params, batched_tasks, max_batch_size = None, None, [], 1 best_task = None if current_tasks is not None: ct_set = set(current_tasks) for task in sorted(self._state.get_active_tasks_by_status(RUNNING), key=self._rank): if task.worker_running == worker_id and task.id not in ct_set: best_task = task if current_tasks is not None: # batch running tasks that weren't claimed since the last get_work go back in the pool self._reset_orphaned_batch_running_tasks(worker_id) greedy_resources = collections.defaultdict(int) worker = self._state.get_worker(worker_id) if self._paused: relevant_tasks = [] elif worker.is_trivial_worker(self._state): relevant_tasks = worker.get_tasks(self._state, PENDING, RUNNING) used_resources = collections.defaultdict(int) greedy_workers = dict() # If there's no resources, then they can grab any task else: relevant_tasks = self._state.get_active_tasks_by_status(PENDING, RUNNING) used_resources = self._used_resources() activity_limit = time.time() - 
self._config.worker_disconnect_delay active_workers = self._state.get_active_workers(last_get_work_gt=activity_limit) greedy_workers = dict((worker.id, worker.info.get("workers", 1)) for worker in active_workers) tasks = list(relevant_tasks) tasks.sort(key=self._rank, reverse=True) for task in tasks: if ( best_task and batched_params and task.family == best_task.family and len(batched_tasks) < max_batch_size and task.is_batchable() and all(task.params.get(name) == value for name, value in unbatched_params.items()) and task.resources == best_task.resources and self._schedulable(task) ): for name, params in batched_params.items(): params.append(task.params.get(name)) batched_tasks.append(task) if best_task: continue if task.status == RUNNING and (task.worker_running in greedy_workers): greedy_workers[task.worker_running] -= 1 for resource, amount in (getattr(task, "resources_running", task.resources) or {}).items(): greedy_resources[resource] += amount if self._schedulable(task) and self._has_resources(task.resources, greedy_resources): in_workers = (assistant and task.runnable) or worker_id in task.workers if in_workers and self._has_resources(task.resources, used_resources): best_task = task batch_param_names, max_batch_size = self._state.get_batcher(worker_id, task.family) if batch_param_names and task.is_batchable(): try: batched_params = {name: [task.params[name]] for name in batch_param_names} unbatched_params = {name: value for name, value in task.params.items() if name not in batched_params} batched_tasks.append(task) except KeyError: batched_params, unbatched_params = None, None else: workers = itertools.chain(task.workers, [worker_id]) if assistant else task.workers for task_worker in workers: if greedy_workers.get(task_worker, 0) > 0: # use up a worker greedy_workers[task_worker] -= 1 # keep track of the resources used in greedy scheduling for resource, amount in (task.resources or {}).items(): greedy_resources[resource] += amount break reply = 
self.count_pending(worker_id) if len(batched_tasks) > 1: batch_string = "|".join(task.id for task in batched_tasks) batch_id = hashlib.new("md5", batch_string.encode("utf-8"), usedforsecurity=False).hexdigest() for task in batched_tasks: self._state.set_batch_running(task, batch_id, worker_id) combined_params = best_task.params.copy() combined_params.update(batched_params) reply["task_id"] = None reply["task_family"] = best_task.family reply["task_module"] = getattr(best_task, "module", None) reply["task_params"] = combined_params reply["batch_id"] = batch_id reply["batch_task_ids"] = [task.id for task in batched_tasks] elif best_task: self.update_metrics_task_started(best_task) self._state.set_status(best_task, RUNNING, self._config) best_task.worker_running = worker_id best_task.resources_running = best_task.resources.copy() best_task.time_running = time.time() self._update_task_history(best_task, RUNNING, host=host) reply["task_id"] = best_task.id reply["task_family"] = best_task.family reply["task_module"] = getattr(best_task, "module", None) reply["task_params"] = best_task.params else: reply["task_id"] = None return reply @rpc_method(attempts=1) def ping(self, **kwargs): worker_id = kwargs["worker"] worker = self._update_worker(worker_id) return {"rpc_messages": worker.fetch_rpc_messages()} def _upstream_status(self, task_id, upstream_status_table): if task_id in upstream_status_table: return upstream_status_table[task_id] elif self._state.has_task(task_id): task_stack = [task_id] while task_stack: dep_id = task_stack.pop() dep = self._state.get_task(dep_id) if dep: if dep.status == DONE: continue if dep_id not in upstream_status_table: if dep.status == PENDING and dep.deps: task_stack += [dep_id] + list(dep.deps) upstream_status_table[dep_id] = "" # will be updated postorder else: dep_status = STATUS_TO_UPSTREAM_MAP.get(dep.status, "") upstream_status_table[dep_id] = dep_status elif upstream_status_table[dep_id] == "" and dep.deps: # This is the postorder 
update step when we set the # status based on the previously calculated child elements status = max((upstream_status_table.get(a_task_id, "") for a_task_id in dep.deps), key=UPSTREAM_SEVERITY_KEY) upstream_status_table[dep_id] = status return upstream_status_table[dep_id] def _serialize_task(self, task_id, include_deps=True, deps=None): task = self._state.get_task(task_id) ret = { "display_name": task.pretty_id, "status": task.status, "workers": list(task.workers), "worker_running": task.worker_running, "time_running": getattr(task, "time_running", None), "start_time": task.time, "last_updated": getattr(task, "updated", task.time), "params": task.public_params, "name": task.family, "priority": task.priority, "resources": task.resources, "resources_running": getattr(task, "resources_running", None), "tracking_url": getattr(task, "tracking_url", None), "status_message": getattr(task, "status_message", None), "progress_percentage": getattr(task, "progress_percentage", None), } if task.status == DISABLED: ret["re_enable_able"] = task.scheduler_disable_time is not None if include_deps: ret["deps"] = list(task.deps if deps is None else deps) if self._config.send_messages and task.status == RUNNING: ret["accepts_messages"] = task.accepts_messages return ret @rpc_method() def graph(self, **kwargs): self.prune() serialized = {} seen = set() for task in self._state.get_active_tasks(): serialized.update(self._traverse_graph(task.id, seen)) return serialized def _filter_done(self, task_ids): for task_id in task_ids: task = self._state.get_task(task_id) if task is None or task.status != DONE: yield task_id def _traverse_graph(self, root_task_id, seen=None, dep_func=None, include_done=True): """Returns the dependency graph rooted at task_id This does a breadth-first traversal to find the nodes closest to the root before hitting the scheduler.max_graph_nodes limit. 
:param root_task_id: the id of the graph's root :return: A map of task id to serialized node """ if seen is None: seen = set() elif root_task_id in seen: return {} if dep_func is None: def dep_func(t): return t.deps seen.add(root_task_id) serialized = {} queue = collections.deque([root_task_id]) while queue: task_id = queue.popleft() task = self._state.get_task(task_id) if task is None or not task.family: logger.debug("Missing task for id [%s]", task_id) # NOTE : If a dependency is missing from self._state there is no way to deduce the # task family and parameters. family_match = TASK_FAMILY_RE.match(task_id) family = family_match.group(1) if family_match else UNKNOWN params = {"task_id": task_id} serialized[task_id] = { "deps": [], "status": UNKNOWN, "workers": [], "start_time": UNKNOWN, "params": params, "name": family, "display_name": task_id, "priority": 0, } else: deps = dep_func(task) if not include_done: deps = list(self._filter_done(deps)) serialized[task_id] = self._serialize_task(task_id, deps=deps) for dep in sorted(deps): if dep not in seen: seen.add(dep) queue.append(dep) if task_id != root_task_id: del serialized[task_id]["display_name"] if len(serialized) >= self._config.max_graph_nodes: break return serialized @rpc_method() def dep_graph(self, task_id, include_done=True, **kwargs): self.prune() if not self._state.has_task(task_id): return {} return self._traverse_graph(task_id, include_done=include_done) @rpc_method() def inverse_dep_graph(self, task_id, include_done=True, **kwargs): self.prune() if not self._state.has_task(task_id): return {} inverse_graph = collections.defaultdict(set) for task in self._state.get_active_tasks(): for dep in task.deps: inverse_graph[dep].add(task.id) return self._traverse_graph(task_id, dep_func=lambda t: inverse_graph[t.id], include_done=include_done) @rpc_method() def task_list(self, status="", upstream_status="", limit=True, search=None, max_shown_tasks=None, **kwargs): """ Query for a subset of tasks by status. 
""" if not search: count_limit = max_shown_tasks or self._config.max_shown_tasks pre_count = self._state.get_active_task_count_for_status(status) if limit and pre_count > count_limit: return {"num_tasks": -1 if upstream_status else pre_count} self.prune() result = {} upstream_status_table = {} # used to memoize upstream status if search is None: def filter_func(_): return True else: terms = search.split() def filter_func(t): return all(term.casefold() in t.pretty_id.casefold() for term in terms) tasks = self._state.get_active_tasks_by_status(status) if status else self._state.get_active_tasks() for task in filter(filter_func, tasks): if task.status != PENDING or not upstream_status or upstream_status == self._upstream_status(task.id, upstream_status_table): serialized = self._serialize_task(task.id, include_deps=False) result[task.id] = serialized if limit and len(result) > (max_shown_tasks or self._config.max_shown_tasks): return {"num_tasks": len(result)} return result def _first_task_display_name(self, worker): task_id = worker.info.get("first_task", "") if self._state.has_task(task_id): return self._state.get_task(task_id).pretty_id else: return task_id @rpc_method() def worker_list(self, include_running=True, **kwargs): self.prune() workers = [ dict( name=worker.id, last_active=worker.last_active, started=worker.started, state=worker.state, first_task_display_name=self._first_task_display_name(worker), num_unread_rpc_messages=len(worker.rpc_messages), **worker.info, ) for worker in self._state.get_active_workers() ] workers.sort(key=lambda worker: worker["started"], reverse=True) if include_running: running = collections.defaultdict(dict) for task in self._state.get_active_tasks_by_status(RUNNING): if task.worker_running: running[task.worker_running][task.id] = self._serialize_task(task.id, include_deps=False) num_pending = collections.defaultdict(int) num_uniques = collections.defaultdict(int) for task in self._state.get_active_tasks_by_status(PENDING): for 
worker in task.workers: num_pending[worker] += 1 if len(task.workers) == 1: num_uniques[list(task.workers)[0]] += 1 for worker in workers: tasks = running[worker["name"]] worker["num_running"] = len(tasks) worker["num_pending"] = num_pending[worker["name"]] worker["num_uniques"] = num_uniques[worker["name"]] worker["running"] = tasks return workers @rpc_method() def resource_list(self): """ Resources usage info and their consumers (tasks). """ self.prune() resources = [dict(name=resource, num_total=r_dict["total"], num_used=r_dict["used"]) for resource, r_dict in self.resources().items()] if self._resources is not None: consumers = collections.defaultdict(dict) for task in self._state.get_active_tasks_by_status(RUNNING): if task.status == RUNNING and task.resources: for resource, amount in task.resources.items(): consumers[resource][task.id] = self._serialize_task(task.id, include_deps=False) for resource in resources: tasks = consumers[resource["name"]] resource["num_consumer"] = len(tasks) resource["running"] = tasks return resources def resources(self): """get total resources and available ones""" used_resources = self._used_resources() ret = collections.defaultdict(dict) for resource, total in self._resources.items(): ret[resource]["total"] = total if resource in used_resources: ret[resource]["used"] = used_resources[resource] else: ret[resource]["used"] = 0 return ret @rpc_method() def task_search(self, task_str, **kwargs): """ Query for a subset of tasks by task_id. 
:param task_str: :return: """ self.prune() result = collections.defaultdict(dict) for task in self._state.get_active_tasks(): if task.id.find(task_str) != -1: serialized = self._serialize_task(task.id, include_deps=False) result[task.status][task.id] = serialized return result @rpc_method() def re_enable_task(self, task_id): serialized = {} task = self._state.get_task(task_id) if task and task.status == DISABLED and task.scheduler_disable_time: self._state.re_enable(task, self._config) serialized = self._serialize_task(task_id) return serialized @rpc_method() def fetch_error(self, task_id, **kwargs): if self._state.has_task(task_id): task = self._state.get_task(task_id) return { "taskId": task_id, "error": task.expl, "displayName": task.pretty_id, "taskParams": task.params, "taskModule": task.module, "taskFamily": task.family, } else: return {"taskId": task_id, "error": ""} @rpc_method() def set_task_status_message(self, task_id, status_message): if self._state.has_task(task_id): task = self._state.get_task(task_id) task.status_message = status_message if task.status == RUNNING and task.batch_id is not None: for batch_task in self._state.get_batch_running_tasks(task.batch_id): batch_task.status_message = status_message @rpc_method() def get_task_status_message(self, task_id): if self._state.has_task(task_id): task = self._state.get_task(task_id) return {"taskId": task_id, "statusMessage": task.status_message} else: return {"taskId": task_id, "statusMessage": ""} @rpc_method() def set_task_progress_percentage(self, task_id, progress_percentage): if self._state.has_task(task_id): task = self._state.get_task(task_id) task.progress_percentage = progress_percentage if task.status == RUNNING and task.batch_id is not None: for batch_task in self._state.get_batch_running_tasks(task.batch_id): batch_task.progress_percentage = progress_percentage @rpc_method() def get_task_progress_percentage(self, task_id): if self._state.has_task(task_id): task = 
self._state.get_task(task_id) return {"taskId": task_id, "progressPercentage": task.progress_percentage} else: return {"taskId": task_id, "progressPercentage": None} @rpc_method() def decrease_running_task_resources(self, task_id, decrease_resources): if self._state.has_task(task_id): task = self._state.get_task(task_id) if task.status != RUNNING: return def decrease(resources, decrease_resources): for resource, decrease_amount in decrease_resources.items(): if decrease_amount > 0 and resource in resources: resources[resource] = max(0, resources[resource] - decrease_amount) decrease(task.resources_running, decrease_resources) if task.batch_id is not None: for batch_task in self._state.get_batch_running_tasks(task.batch_id): decrease(batch_task.resources_running, decrease_resources) @rpc_method() def get_running_task_resources(self, task_id): if self._state.has_task(task_id): task = self._state.get_task(task_id) return {"taskId": task_id, "resources": getattr(task, "resources_running", None)} else: return {"taskId": task_id, "resources": None} def _update_task_history(self, task, status, host=None): try: if status == DONE or status == FAILED: successful = status == DONE self._task_history.task_finished(task, successful) elif status == PENDING: self._task_history.task_scheduled(task) elif status == RUNNING: self._task_history.task_started(task, host) except BaseException: logger.warning("Error saving Task history", exc_info=True) @property def task_history(self): # Used by server.py to expose the calls return self._task_history @rpc_method() def update_metrics_task_started(self, task): self._state._metrics_collector.handle_task_started(task) @rpc_method() def report_task_statistics(self, task_id, statistics): if self._state.has_task(task_id): task = self._state.get_task(task_id) self._state._metrics_collector.handle_task_statistics(task, statistics) ================================================ FILE: luigi/server.py ================================================ 
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Simple REST server that takes commands in a JSON payload.
Interface to the :py:class:`~luigi.scheduler.Scheduler` class.
See :doc:`/central_scheduler` for more info.
"""
#
# Description: Added code to visualize how long each task spends running
# until it reaches its next status (failed or done).
# At "{base_url}/tasklist", all completed (failed or done) tasks are shown.
# At "{base_url}/tasklist", a user can select one specific task to see
# how its running time has changed over time.
# At "{base_url}/tasklist/{task_name}", a multi-bar graph shows the changes
# in running time of the selected task up to its next status (failed or done).
# This visualization shows how the running time of a specific task
# has changed over time.
#
# Copyright 2015 Naver Corp.
# Author Yeseul Park (yeseul.park@navercorp.com) # import atexit import datetime import importlib import json import logging import os import signal import sys import time import tornado.httpserver import tornado.ioloop import tornado.netutil import tornado.web from luigi import Config, parameter from luigi.scheduler import RPC_METHODS, Scheduler logger = logging.getLogger("luigi.server") class cors(Config): enabled = parameter.BoolParameter(default=False, description="Enables CORS support.") allowed_origins = parameter.ListParameter(default=(), description="A list of allowed origins. Used only if `allow_any_origin` is false.") allow_any_origin = parameter.BoolParameter(default=False, description="Accepts requests from any origin.") allow_null_origin = parameter.BoolParameter(default=False, description="Allows the request to set `null` value of the `Origin` header.") max_age = parameter.IntParameter(default=86400, description="Content of `Access-Control-Max-Age`.") allowed_methods = parameter.Parameter(default="GET, OPTIONS", description="Content of `Access-Control-Allow-Methods`.") allowed_headers = parameter.Parameter(default="Accept, Content-Type, Origin", description="Content of `Access-Control-Allow-Headers`.") exposed_headers = parameter.Parameter(default="", description="Content of `Access-Control-Expose-Headers`.") allow_credentials = parameter.BoolParameter(default=False, description="Indicates that the actual request can include user credentials.") def __init__(self, *args, **kwargs): super(cors, self).__init__(*args, **kwargs) self.allowed_origins = set(i for i in self.allowed_origins if i not in ["*", "null"]) class RPCHandler(tornado.web.RequestHandler): """ Handle remote scheduling calls using rpc.RemoteSchedulerResponder. 
""" def __init__(self, *args, **kwargs): super(RPCHandler, self).__init__(*args, **kwargs) self._cors_config = cors() def initialize(self, scheduler): self._scheduler = scheduler def options(self, *args): if self._cors_config.enabled: self._handle_cors_preflight() self.set_status(204) self.finish() def get(self, method): if method not in RPC_METHODS: self.send_error(404) return payload = self.get_argument("data", default="{}") arguments = json.loads(payload) if hasattr(self._scheduler, method): result = getattr(self._scheduler, method)(**arguments) if self._cors_config.enabled: self._handle_cors() self.write({"response": result}) # wrap all json response in a dictionary else: self.send_error(404) post = get def _handle_cors_preflight(self): origin = self.request.headers.get("Origin") if not origin: return if origin == "null": if self._cors_config.allow_null_origin: self.set_header("Access-Control-Allow-Origin", "null") self._set_other_cors_headers() else: if self._cors_config.allow_any_origin: self.set_header("Access-Control-Allow-Origin", "*") self._set_other_cors_headers() elif origin in self._cors_config.allowed_origins: self.set_header("Access-Control-Allow-Origin", origin) self._set_other_cors_headers() def _handle_cors(self): origin = self.request.headers.get("Origin") if not origin: return if origin == "null": if self._cors_config.allow_null_origin: self.set_header("Access-Control-Allow-Origin", "null") else: if self._cors_config.allow_any_origin: self.set_header("Access-Control-Allow-Origin", "*") elif origin in self._cors_config.allowed_origins: self.set_header("Access-Control-Allow-Origin", origin) self.set_header("Vary", "Origin") def _set_other_cors_headers(self): self.set_header("Access-Control-Max-Age", str(self._cors_config.max_age)) self.set_header("Access-Control-Allow-Methods", self._cors_config.allowed_methods) self.set_header("Access-Control-Allow-Headers", self._cors_config.allowed_headers) if self._cors_config.allow_credentials: 
self.set_header("Access-Control-Allow-Credentials", "true") if self._cors_config.exposed_headers: self.set_header("Access-Control-Expose-Headers", self._cors_config.exposed_headers) class BaseTaskHistoryHandler(tornado.web.RequestHandler): def initialize(self, scheduler): self._scheduler = scheduler def get_template_path(self): return importlib.resources.files("templates").name class AllRunHandler(BaseTaskHistoryHandler): def get(self): all_tasks = self._scheduler.task_history.find_all_runs() tasknames = [task.name for task in all_tasks] # show all tasks with their name list to be selected # why all tasks? the duration of the event history of a selected task # can be more than 24 hours. self.render("menu.html", tasknames=tasknames) class SelectedRunHandler(BaseTaskHistoryHandler): def get(self, name): statusResults = {} taskResults = [] # get all tasks that has been updated all_tasks = self._scheduler.task_history.find_all_runs() # get events history for all tasks all_tasks_event_history = self._scheduler.task_history.find_all_events() # build the dictionary tasks with index: id, value: task_name tasks = {task.id: str(task.name) for task in all_tasks} for task in all_tasks_event_history: # if the name of user-selected task is in tasks, get its task_id if tasks.get(task.task_id) == str(name): status = str(task.event_name) if status not in statusResults: statusResults[status] = [] # append the id, task_id, ts, y with 0, next_process with null # for the status(running/failed/done) of the selected task statusResults[status].append(({"id": str(task.id), "task_id": str(task.task_id), "x": from_utc(str(task.ts)), "y": 0, "next_process": ""})) # append the id, task_name, task_id, status, datetime, timestamp # for the selected task taskResults.append( { "id": str(task.id), "taskName": str(name), "task_id": str(task.task_id), "status": str(task.event_name), "datetime": str(task.ts), "timestamp": from_utc(str(task.ts)), } ) statusResults = json.dumps(statusResults) 
taskResults = json.dumps(taskResults) statusResults = tornado.escape.xhtml_unescape(str(statusResults)) taskResults = tornado.escape.xhtml_unescape(str(taskResults)) self.render("history.html", name=name, statusResults=statusResults, taskResults=taskResults) def from_utc(utcTime, fmt=None): """convert UTC time string to time.struct_time: change datetime.datetime to time, return time.struct_time type""" if fmt is None: try_formats = ["%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S"] else: try_formats = [fmt] for fmt in try_formats: try: time_struct = datetime.datetime.strptime(utcTime, fmt) except ValueError: pass else: date = int(time.mktime(time_struct.timetuple())) return date else: raise ValueError("No UTC format matches {}".format(utcTime)) class RecentRunHandler(BaseTaskHistoryHandler): def get(self): with self._scheduler.task_history._session(None) as session: tasks = self._scheduler.task_history.find_latest_runs(session) self.render("recent.html", tasks=tasks) class ByNameHandler(BaseTaskHistoryHandler): def get(self, name): with self._scheduler.task_history._session(None) as session: tasks = self._scheduler.task_history.find_all_by_name(name, session) self.render("recent.html", tasks=tasks) class ByIdHandler(BaseTaskHistoryHandler): def get(self, id): with self._scheduler.task_history._session(None) as session: task = self._scheduler.task_history.find_task_by_id(id, session) self.render("show.html", task=task) class ByTaskIdHandler(BaseTaskHistoryHandler): def get(self, task_id): with self._scheduler.task_history._session(None) as session: task = self._scheduler.task_history.find_task_by_task_id(task_id, session) self.render("show.html", task=task) class ByParamsHandler(BaseTaskHistoryHandler): def get(self, name): payload = self.get_argument("data", default="{}") arguments = json.loads(payload) with self._scheduler.task_history._session(None) as session: tasks = self._scheduler.task_history.find_all_by_parameters(name, session=session, **arguments) 
self.render("recent.html", tasks=tasks) class RootPathHandler(BaseTaskHistoryHandler): def get(self): # we omit the leading slash in case the visualizer is behind a different # path (as in a reverse proxy setup) # # For example, if luigi is behind my.app.com/my/luigi/, we want / to # redirect relative (so it goes to my.app.com/my/luigi/static/visualizer/index.html) # instead of absolute (which would be my.app.com/static/visualizer/index.html) self.redirect("static/visualiser/index.html") def head(self): """HEAD endpoint for health checking the scheduler""" self.set_status(204) self.finish() class MetricsHandler(tornado.web.RequestHandler): def initialize(self, scheduler): self._scheduler = scheduler def get(self): metrics_collector = self._scheduler._state._metrics_collector metrics = metrics_collector.generate_latest() if metrics: metrics_collector.configure_http_handler(self) self.write(metrics) def app(scheduler): settings = { "static_path": os.path.join(os.path.dirname(__file__), "static"), "unescape": tornado.escape.xhtml_unescape, "compress_response": True, } handlers = [ (r"/api/(.*)", RPCHandler, {"scheduler": scheduler}), (r"/", RootPathHandler, {"scheduler": scheduler}), (r"/tasklist", AllRunHandler, {"scheduler": scheduler}), (r"/tasklist/(.*?)", SelectedRunHandler, {"scheduler": scheduler}), (r"/history", RecentRunHandler, {"scheduler": scheduler}), (r"/history/by_name/(.*?)", ByNameHandler, {"scheduler": scheduler}), (r"/history/by_id/(.*?)", ByIdHandler, {"scheduler": scheduler}), (r"/history/by_task_id/(.*?)", ByTaskIdHandler, {"scheduler": scheduler}), (r"/history/by_params/(.*?)", ByParamsHandler, {"scheduler": scheduler}), (r"/metrics", MetricsHandler, {"scheduler": scheduler}), ] api_app = tornado.web.Application(handlers, **settings) return api_app def _init_api(scheduler, api_port=None, address=None, unix_socket=None): api_app = app(scheduler) if unix_socket is not None: api_sockets = [tornado.netutil.bind_unix_socket(unix_socket)] else: 
api_sockets = tornado.netutil.bind_sockets(api_port, address=address) server = tornado.httpserver.HTTPServer(api_app) server.add_sockets(api_sockets) # Return the bound socket names. Useful for connecting client in test scenarios. return [s.getsockname() for s in api_sockets] def run(api_port=8082, address=None, unix_socket=None, scheduler=None): """ Runs one instance of the API server. """ if scheduler is None: scheduler = Scheduler() # load scheduler state scheduler.load() _init_api( scheduler=scheduler, api_port=api_port, address=address, unix_socket=unix_socket, ) # prune work DAG every 60 seconds pruner = tornado.ioloop.PeriodicCallback(scheduler.prune, 60000) pruner.start() def shutdown_handler(signum, frame): exit_handler() sys.exit(0) @atexit.register def exit_handler(): logger.info("Scheduler instance shutting down") scheduler.dump() stop() signal.signal(signal.SIGINT, shutdown_handler) signal.signal(signal.SIGTERM, shutdown_handler) if os.name == "nt": signal.signal(signal.SIGBREAK, shutdown_handler) else: signal.signal(signal.SIGQUIT, shutdown_handler) logger.info("Scheduler starting up") tornado.ioloop.IOLoop.instance().start() def stop(): tornado.ioloop.IOLoop.instance().stop() if __name__ == "__main__": run() ================================================ FILE: luigi/setup_logging.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2018 Vote Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """ This module contains helper classes for configuring logging for luigid and workers via command line arguments and options from config files. """ import logging import logging.config import os.path from configparser import NoSectionError from luigi.configuration import LuigiConfigParser, get_config from luigi.freezing import recursively_unfreeze class BaseLogging: config = get_config() @classmethod def _section(cls, opts): """Get logging settings from config file section "logging".""" if isinstance(cls.config, LuigiConfigParser): return False try: logging_config = cls.config["logging"] except (TypeError, KeyError, NoSectionError): return False logging.config.dictConfig(recursively_unfreeze(logging_config)) return True @classmethod def setup(cls, opts=type("opts", (), {"background": None, "logdir": None, "logging_conf_file": None, "log_level": "DEBUG"})): """Setup logging via CLI params and config.""" logger = logging.getLogger("luigi") if cls._configured: logger.info("logging already configured") return False cls._configured = True if cls.config.getboolean("core", "no_configure_logging", False): logger.info("logging disabled in settings") return False configured = cls._cli(opts) if configured: logger = logging.getLogger("luigi") logger.info("logging configured via special settings") return True configured = cls._conf(opts) if configured: logger = logging.getLogger("luigi") logger.info("logging configured via *.conf file") return True configured = cls._section(opts) if configured: logger = logging.getLogger("luigi") logger.info("logging configured via config section") return True configured = cls._default(opts) if configured: logger = logging.getLogger("luigi") logger.info("logging configured by default settings") return configured class DaemonLogging(BaseLogging): """Configure logging for luigid""" _configured = False _log_format = "%(asctime)s %(name)s[%(process)s] %(levelname)s: %(message)s" @classmethod def _cli(cls, opts): """Setup logging via CLI options 
If `--background` -- set INFO level for root logger. If `--logdir` -- set logging with next params: default Luigi's formatter, INFO level, output in logdir in `luigi-server.log` file """ if opts.background: logging.getLogger().setLevel(logging.INFO) return True if opts.logdir: logging.basicConfig(level=logging.INFO, format=cls._log_format, filename=os.path.join(opts.logdir, "luigi-server.log")) return True return False @classmethod def _conf(cls, opts): """Setup logging via ini-file from logging_conf_file option.""" logging_conf = cls.config.get("core", "logging_conf_file", None) if logging_conf is None: return False if not os.path.exists(logging_conf): # FileNotFoundError added only in Python 3.3 # https://docs.python.org/3/whatsnew/3.3.html#pep-3151-reworking-the-os-and-io-exception-hierarchy raise OSError("Error: Unable to locate specified logging configuration file!") logging.config.fileConfig(logging_conf) return True @classmethod def _default(cls, opts): """Setup default logger""" logging.basicConfig(level=logging.INFO, format=cls._log_format) return True # Part of this logic taken for dropped function "setup_interface_logging" class InterfaceLogging(BaseLogging): """Configure logging for worker""" _configured = False @classmethod def _cli(cls, opts): return False @classmethod def _conf(cls, opts): """Setup logging via ini-file from logging_conf_file option.""" if not opts.logging_conf_file: return False if not os.path.exists(opts.logging_conf_file): # FileNotFoundError added only in Python 3.3 # https://docs.python.org/3/whatsnew/3.3.html#pep-3151-reworking-the-os-and-io-exception-hierarchy raise OSError("Error: Unable to locate specified logging configuration file!") logging.config.fileConfig(opts.logging_conf_file, disable_existing_loggers=False) return True @classmethod def _default(cls, opts): """Setup default logger""" level = getattr(logging, opts.log_level, logging.DEBUG) logger = logging.getLogger("luigi-interface") logger.setLevel(level) 
stream_handler = logging.StreamHandler() stream_handler.setLevel(level) formatter = logging.Formatter("%(levelname)s: %(message)s") stream_handler.setFormatter(formatter) logger.addHandler(stream_handler) return True ================================================ FILE: luigi/static/visualiser/css/luigi.css ================================================ .nodeCircle { stroke: #fff; stroke-width: 1.5px; } text { font-size:8pt; } .link { stroke: #999; stroke-opacity: .6; } svg { border:1px solid #DDDDDD; overflow: inherit; } .taskRow { word-break:break-all; } @-webkit-keyframes flash { 0%, 50%, 100% { opacity: 1; } 25%, 75% { opacity: 0.2; } } @keyframes flash { 0%, 50%, 100% { opacity: 1; } 25%, 75% { opacity: 0.2; } } .RUNNING { -webkit-animation-duration: 5s; animation-duration: 5s; -webkit-animation-fill-mode: both; animation-fill-mode: both; -webkit-animation-iteration-count: 1; animation-iteration-count: 1; } .live.map { width: 100%; height: 600px; background: #333; } .live.map text { font-weight: 300; font-size: 14px; } .live.map .node rect { stroke-width: 1.5px; stroke: #bbb; fill: #666; } .live.map .status { height: 100%; width: 15px; display: block; float: left; border-top-left-radius: 5px; border-bottom-left-radius: 5px; margin-right: 4px; } .live.map .DONE .status { background-color: #7f7; } .live.map .RUNNING .status { background-color: #7f7; } .live.map .PENDING .status { background-color: #FFFF46; } .live.map .ERROR .status { background-color: #f77; } .live.map .FAILED .status { background-color: #dd4b39; } .live.map .RUNNING .queue { color: #f77; } .live.map .DISABLED .status { background-color: #aaaaaa; } .RUNNING { -webkit-animation-name: flash; animation-name: flash; } .live.map .consumers { margin-right: 2px; } .live.map .consumers, .live.map .name { margin-top: 4px; } .live.map .consumers:after { content: "x"; } .live.map .queue { display: block; float: left; width: 130px; height: 20px; font-size: 12px; margin-top: 2px; } .live.map .node g div 
{
    width: 200px;
    height: 40px;
    color: #fff;
}
.live.map .node g div span.consumers {
    display: inline-block;
    width: 20px;
}
.live.map .edgeLabel text {
    width: 50px;
    fill: #fff;
}
.live.map .edgePath path {
    stroke: #999;
    stroke-width: 1.5px;
    fill: #999;
}
td.details-control {
    cursor: pointer;
}
/* 24px square status badge with 2px rounded corners and centred content */
span.status-icon {
    border-top-left-radius: 2px;
    border-top-right-radius: 2px;
    border-bottom-right-radius: 2px;
    border-bottom-left-radius: 2px;
    display: inline-block;
    height: 24px;
    width: 24px;
    text-align: center;
    font-size: 12px;
    line-height: 24px;
}
#serverSide {
    float: right;
    margin: 4px;
}
.infoBar {
    min-height: 80px;
}
#taskTable_filter {
    margin-top: 9px;
}
#loadTaskForm input {
    width: 20em;
}
/* Worker list box */
#workerList .box-tools > div {
    display: inline-block;
}
#workerList .btn-set-workers > span.caret {
    margin-left: 4px;
}
#workerList .box-tools > span.label-unread-worker-messages {
    margin-right: 6px;
    vertical-align: middle;
    font-style: italic;
    color: red;
}
#resourceList i.resources-collapse {
    padding-left: 10px;
}
/* Clear-filter control: dims on hover, shows an inset shadow while pressed */
#clear-task-filter {
    margin-left: 20px;
    cursor: pointer;
}
#clear-task-filter:hover {
    opacity: 0.9;
}
#clear-task-filter:active {
    box-shadow: inset -2px 3px 1px rgba(0,0,0,0.2);
}
/* Folder entries in the sidebar menu; highlighted green while .expanded */
.sidebar-menu li > a.sidebar-folder {
    font-weight: bold;
    background-color: #ddd !important;
}
.sidebar-menu li > a.sidebar-folder:hover {
    opacity: 0.9;
}
.sidebar-menu li > a.sidebar-folder.expanded {
    background-color: rgb(0, 166, 90) !important;
    color: white !important;
}
.popover{
    max-width: 100% !important;
}

================================================
FILE: luigi/static/visualiser/css/tipsy.css
================================================
.tipsy { font-size: 10px; position: absolute; padding: 5px; z-index: 100000; }
  .tipsy-inner { background-color: #000; color: #FFF; max-width: 200px; padding: 5px 8px 4px 8px; text-align: center; }

  /* Rounded corners */
  .tipsy-inner { border-radius: 3px; -moz-border-radius: 3px; -webkit-border-radius: 3px; }

  /* Drop shadow (active; remove this rule to disable it) */
  .tipsy-inner { box-shadow: 0 0 5px #000000; -webkit-box-shadow: 0 0 5px #000000; -moz-box-shadow: 0 0 5px #000000; }

  /* Arrow: a zero-size box whose borders form the triangle */
  .tipsy-arrow { position: absolute; width: 0; height: 0; line-height: 0; border: 5px dashed #000; }

  /* Rules to colour arrows */
  .tipsy-arrow-n { border-bottom-color: #000; }
  .tipsy-arrow-s { border-top-color: #000; }
  .tipsy-arrow-e { border-left-color: #000; }
  .tipsy-arrow-w { border-right-color: #000; }

  /* Arrow placement for each gravity (tooltip position relative to the element) */
  .tipsy-n .tipsy-arrow { top: 0px; left: 50%; margin-left: -5px; border-bottom-style: solid; border-top: none; border-left-color: transparent; border-right-color: transparent; }
    .tipsy-nw .tipsy-arrow { top: 0; left: 10px; border-bottom-style: solid; border-top: none; border-left-color: transparent; border-right-color: transparent;}
    .tipsy-ne .tipsy-arrow { top: 0; right: 10px; border-bottom-style: solid; border-top: none;  border-left-color: transparent; border-right-color: transparent;}
  .tipsy-s .tipsy-arrow { bottom: 0; left: 50%; margin-left: -5px; border-top-style: solid; border-bottom: none;  border-left-color: transparent; border-right-color: transparent; }
    .tipsy-sw .tipsy-arrow { bottom: 0; left: 10px; border-top-style: solid; border-bottom: none;  border-left-color: transparent; border-right-color: transparent; }
    .tipsy-se .tipsy-arrow { bottom: 0; right: 10px; border-top-style: solid; border-bottom: none; border-left-color: transparent; border-right-color: transparent; }
  .tipsy-e .tipsy-arrow { right: 0; top: 50%; margin-top: -5px; border-left-style: solid; border-right: none; border-top-color: transparent; border-bottom-color: transparent; }
  .tipsy-w .tipsy-arrow { left: 0; top: 50%; margin-top: -5px; border-right-style: solid; border-left: none; border-top-color: transparent; border-bottom-color: transparent; }

================================================
FILE: luigi/static/visualiser/index.html
================================================
Luigi Task Visualiser
    Pending Tasks ?
    Running Tasks ?
    Batch Running Tasks ?
    Done Tasks ?
    Failed Tasks ?
    Upstream Failure ?
    Disabled Tasks ?
    Upstream Disabled ?
    Name Details Priority Time Actions

    Dependency Graph
    ================================================ FILE: luigi/static/visualiser/js/graph.js ================================================ Graph = (function() { var statusColors = { "FAILED":"#DD0000", "RUNNING":"#0044DD", "BATCH_RUNNING":"#BB00BB", "PENDING":"#EEBB00", "DONE":"#00DD00", "DISABLED":"#808080", "UNKNOWN":"#000000", "TRUNCATED":"#FF00FF" }; /* Line height for items in task status legend */ var legendLineHeight = 20; /* Height of vertical space between nodes */ var nodeHeight = 10; /* Amount of horizontal space given for each node */ var nodeWidth = 200; /* Random horizontal offset for each row */ var jitterWidth = 100; /* Calculate minimum SVG height required for legend */ var legendMaxY = (function () { return Object.keys(statusColors).length * legendLineHeight + ( legendLineHeight / 2 ) })(); var legendWidth = 110; function nodeFromTask(task) { var deps = task.deps; deps.sort(); return { name: task.name, taskId: task.taskId, status: task.status, trackingUrl: this.hashBase + task.taskId, deps: deps, params: task.params, priority: task.priority, depth: -1 }; } /* Convert array to dict by indexing on propertyName */ function uniqueIndexByProperty(data, propertyName) { var nodeIndex = {}; $.each(data, function(i, dataPoint) { nodeIndex[dataPoint[propertyName]] = i; }); return nodeIndex; } /* Create edges between the supplied node using the deps property of each node */ function createDependencyEdges(nodes, nodeIndex) { var edges = []; $.each(nodes, function(i, task) { $.each(task.deps, function(j, dep) { if (nodeIndex[dep]) { edges.push({ source: nodes[nodeIndex[task.taskId]], target: nodes[nodeIndex[dep]] }); } }); }); return edges; } /* Compute the depth of each node for layout purposes */ function computeDepth(nodes, nodeIndex) { var selfDependencies = false function descend(n, depth) { if (n.depth === undefined || depth > n.depth) { n.depth = depth; $.each(n.deps, function(i, dep) { if (nodeIndex[dep]) { var child_node = nodes[nodeIndex[dep]] 
descend(child_node, depth + 1); if (!selfDependencies && n.name == child_node.name) { selfDependencies = true; } } }); } } descend(nodes[0], 0); return selfDependencies } /* Group tasks, so all tasks with the same name appear at the same depth. */ function groupTasks(nodes) { // compute average assigned depth var taskDepths = {}; $.each(nodes, function(i, n) { if (taskDepths[n.name] === undefined) { taskDepths[n.name] = [n.depth]; } else { taskDepths[n.name].push(n.depth); } }); var averages = []; $.each(taskDepths, function(key, array) { var total = 0; for (var i in array) total += array[i]; var mean = total / array.length; averages.push([key, mean]); }); // sort tasks averages.sort( function(first, second) { return first[1] - second[1]; }); // reassign task depths and node depths var classDepths = {} $.each(averages, function(i, a) { classDepths[a[0]] = i; }); $.each(nodes, function(i, n) { n.depth = classDepths[n.name]; }); return classDepths } /* Compute the depth of each node for layout purposes, returns the number of nodes at each depth level (for layout purposes) */ function computeRows(nodes, nodeIndex) { var selfDependencies = computeDepth(nodes, nodeIndex) if (!selfDependencies) { var classDepths = groupTasks(nodes) } var rowSizes = []; function placeNodes(n, depth) { if (rowSizes[depth] === undefined) { rowSizes[depth] = 0; } if (n.xOrder === undefined && depth === n.depth) { n.xOrder = rowSizes[depth]; rowSizes[depth]++; $.each(n.deps, function(i, dep) { if (nodeIndex[dep]) { var next_node = nodes[nodeIndex[dep]] var next_depth = (selfDependencies ? depth + 1 : classDepths[next_node.name]) placeNodes(next_node, next_depth); } }); } } placeNodes(nodes[0], 0); return rowSizes; } /* Format nodes according to their depth and horizontal sort order. Algorithm: evenly distribute nodes along each depth level, offsetting each by the text line height to prevent overlapping text. This is done within multiple columns to keep the levels from being too tall. 
The column width is at least nodeWidth to ensure readability. The height of each level is determined by number of nodes divided by number of columns, rounded up. */ function layoutNodes(nodes, rowSizes) { var numCols = Math.max(2, Math.floor((graphWidth - jitterWidth) / nodeWidth)); function rowStartPosition(depth) { if (depth === 0) return 20; var rowHeight = Math.ceil(rowSizes[depth-1] / numCols); return rowStartPosition(depth-1)+Math.max(rowHeight * nodeHeight + 100); } var jitter = [] for (var i in rowSizes) { jitter[i] = Math.ceil(Math.random() * jitterWidth) } $.each(nodes, function(i, node) { var numRows = Math.ceil(rowSizes[node.depth] / numCols); var levelCols = Math.ceil(rowSizes[node.depth] / numRows); var row = node.xOrder % numRows; var col = node.xOrder / numRows; node.x = ((col + 1) / (levelCols + 1)) * (graphWidth - jitterWidth - nodeWidth) + jitter[node.depth]; node.y = rowStartPosition(node.depth) + row * nodeHeight; }); } /* Parses a list of tasks to a graph format */ function createGraph(tasks, hashBase) { if (tasks.length === 0) return {nodes: [], links: []}; this.hashBase = hashBase; var nodes = $.map(tasks, nodeFromTask); var nodeIndex = uniqueIndexByProperty(nodes, "taskId"); var rowSizes = computeRows(nodes, nodeIndex); nodes = $.map(nodes, function(node) { return node.depth >= 0 ? 
node: null; }); layoutNodes(nodes, rowSizes); // We need to re-index nodes after filtering nodeIndex = uniqueIndexByProperty(nodes, "taskId"); var edges = createDependencyEdges(nodes, nodeIndex); return { nodes: nodes, links: edges }; } function findBounds(nodes) { var maxX = 0; var maxY = legendMaxY; $.each(nodes, function(i, node) { if (node.x>maxX) maxX = node.x; if (node.y>maxY) maxY = node.y; }); return { x:maxX, y:maxY }; } var graphWidth = window.innerWidth - 80; function DependencyGraph(containerElement) { this.svg = $(svgElement("svg")).appendTo($(containerElement)); } /* We need custom element creators for svg nodes and xlink attributes because jQuery doesn't support namespaces properly */ function svgElement(name) { return document.createElementNS("http://www.w3.org/2000/svg", name); } function svgLink(url) { var element = svgElement("a"); element.setAttributeNS("http://www.w3.org/1999/xlink", "href", url); return element; } DependencyGraph.prototype.renderGraph = function() { var self = this; $.each(this.graph.links, function(i, link) { var line = $(svgElement("line")) .attr("class","link") .attr("x1", link.source.x) .attr("y1", link.source.y) .attr("x2", link.target.x) .attr("y2", link.target.y) .appendTo(self.svg); }); $.each(this.graph.nodes, function(i, node) { var g = $(svgElement("g")) .addClass("node") .attr("transform", "translate(" + node.x + "," + node.y +")") .appendTo(self.svg); $(svgElement("circle")) .addClass("nodeCircle") .attr("r", 7) .attr("fill", statusColors[node.status]) .appendTo(g); $(svgLink(node.trackingUrl)) .append( $(svgElement("text")) .text(escapeHtml(node.name)) .attr("y", 3)) .attr("class","graph-node-a") .attr("data-task-status", node.status) .attr("data-task-id", node.taskId) .appendTo(g); var titleText = node.name; var content = $.map(node.params, function (value, name) { return escapeHtml(name + ": " + value); }).join("
    "); g.attr("title", titleText) .popover({ trigger: 'hover', container: 'body', html: true, placement: 'top', content: content }); }); // Legend for Task status var legend = $(svgElement("g")) .addClass("legend") .appendTo(self.svg); $(svgElement("rect")) .attr("x", -1) .attr("y", -1) .attr("width", legendWidth + "px") .attr("height", legendMaxY + "px") .attr("fill", "#FFF") .attr("stroke", "#DDD") .appendTo(legend); var x = 0; $.each(statusColors, function(key, color) { var c = $(svgElement("circle")) .addClass("nodeCircle") .attr("r", 7) .attr("cx", legendLineHeight) .attr("cy", (legendLineHeight-4)+(x*legendLineHeight)) .attr("fill", color) .appendTo(legend); $(svgElement("text")) .text(escapeHtml(key.charAt(0).toUpperCase() + key.substring(1).toLowerCase().replace(/_./gi, function (x) { return " " + x[1].toUpperCase(); }))) .attr("x", legendLineHeight + 14) .attr("y", legendLineHeight+(x*legendLineHeight)) .appendTo(legend); x++; }); }; DependencyGraph.prototype.updateData = function(taskList, hashBase) { $('.popover').popover('destroy'); this.graph = createGraph(taskList, hashBase); bounds = findBounds(this.graph.nodes); this.renderGraph(); this.svg.attr("height", bounds.y+10); this.svg.attr("width", graphWidth+10); this.svg[0].setAttributeNS("http://www.w3.org/2000/svg", "preserveAspectRatio", "xMidYMid meet"); this.svg[0].setAttributeNS("http://www.w3.org/2000/svg", "viewBox", "0 0 " + graphWidth + " " + (bounds.y+10)); }; return { DependencyGraph: DependencyGraph, testableMethods: { nodeFromTask: nodeFromTask, uniqueIndexByProperty: uniqueIndexByProperty, createDependencyEdges: createDependencyEdges, computeDepth: computeDepth, computeRows: computeRows, createGraph: createGraph, findBounds: findBounds } }; })(); ================================================ FILE: luigi/static/visualiser/js/luigi.js ================================================ var LuigiAPI = (function() { function LuigiAPI (urlRoot) { this.urlRoot = urlRoot; } function 
flatten(response, rootId) { var flattened = []; // Make the requested taskId the first in the list if (rootId && response[rootId]) { var rootNode = response[rootId]; rootNode.taskId=rootId; flattened.push(rootNode); delete response[rootId]; } $.each(response, function(key, value) { value.taskId = key; flattened.push(value); }); return flattened; } function flatten_running(response) { $.each(response, function(key, value) { value.running = flatten(value.running); }); return response; } function jsonRPC(url, paramObject, callback) { return $.ajax(url, { data: {data: JSON.stringify(paramObject)}, method: "GET", success: callback, dataType: "json" }); } function searchTerm() { // FIXME : leaky API. This shouldn't rely on the DOM. if ($('#serverSideCheckbox')[0].checked) { return $('#taskTable_filter').find('input').val(); } else { return ''; } } LuigiAPI.prototype.getDependencyGraph = function (taskId, callback, include_done) { return jsonRPC(this.urlRoot + "/dep_graph", {task_id: taskId, include_done: include_done}, function(response) { callback(flatten(response.response, taskId)); }); }; LuigiAPI.prototype.getInverseDependencyGraph = function (taskId, callback, include_done) { return jsonRPC(this.urlRoot + "/inverse_dep_graph", {task_id: taskId, include_done: include_done}, function(response) { callback(flatten(response.response, taskId)); }); }; LuigiAPI.prototype.forgiveFailures = function (taskId, callback) { return jsonRPC(this.urlRoot + "/forgive_failures", {task_id: taskId}, function(response) { callback(flatten(response.response)); }); }; LuigiAPI.prototype.markAsDone = function (taskId, callback) { return jsonRPC(this.urlRoot + "/mark_as_done", {task_id: taskId}, function(response) { callback(flatten(response.response)); }); }; LuigiAPI.prototype.getFailedTaskList = function(callback) { return jsonRPC(this.urlRoot + "/task_list", {status: "FAILED", upstream_status: "", search: searchTerm()}, function(response) { callback(flatten(response.response)); }); }; 
LuigiAPI.prototype.getUpstreamFailedTaskList = function(callback) { return jsonRPC(this.urlRoot + "/task_list", {status: "PENDING", upstream_status: "UPSTREAM_FAILED", search: searchTerm()}, function(response) { callback(flatten(response.response)); }); }; LuigiAPI.prototype.getDoneTaskList = function(callback) { return jsonRPC(this.urlRoot + "/task_list", {status: "DONE", upstream_status: "", search: searchTerm()}, function(response) { callback(flatten(response.response)); }); }; LuigiAPI.prototype.reEnable = function(taskId, callback) { return jsonRPC(this.urlRoot + "/re_enable_task", {task_id: taskId}, function(response) { callback(response.response); }); }; LuigiAPI.prototype.getErrorTrace = function(taskId, callback) { return jsonRPC(this.urlRoot + "/fetch_error", {task_id: taskId}, function(response) { callback(response.response); }); }; LuigiAPI.prototype.getTaskStatusMessage = function(taskId, callback) { return jsonRPC(this.urlRoot + "/get_task_status_message", {task_id: taskId}, function(response) { callback(response.response); }); }; LuigiAPI.prototype.getTaskProgressPercentage = function(taskId, callback) { return jsonRPC(this.urlRoot + "/get_task_progress_percentage", {task_id: taskId}, function(response) { callback(response.response); }); }; LuigiAPI.prototype.getRunningTaskList = function(callback) { return jsonRPC(this.urlRoot + "/task_list", {status: "RUNNING", upstream_status: "", search: searchTerm()}, function(response) { callback(flatten(response.response)); }); }; LuigiAPI.prototype.getBatchRunningTaskList = function(callback) { return jsonRPC(this.urlRoot + "/task_list", {status: "BATCH_RUNNING", upstream_status: "", search: searchTerm()}, function(response) { callback(flatten(response.response)); }); }; LuigiAPI.prototype.getPendingTaskList = function(callback) { return jsonRPC(this.urlRoot + "/task_list", {status: "PENDING", upstream_status: "", search: searchTerm()}, function(response) { callback(flatten(response.response)); }); }; 
LuigiAPI.prototype.getDisabledTaskList = function(callback) { jsonRPC(this.urlRoot + "/task_list", {status: "DISABLED", upstream_status: "", search: searchTerm()}, function(response) { callback(flatten(response.response)); }); }; LuigiAPI.prototype.getUpstreamDisabledTaskList = function(callback) { jsonRPC(this.urlRoot + "/task_list", {status: "PENDING", upstream_status: "UPSTREAM_DISABLED", search: searchTerm()}, function(response) { callback(flatten(response.response)); }); }; LuigiAPI.prototype.getWorkerList = function(callback) { jsonRPC(this.urlRoot + "/worker_list", {}, function(response) { callback(flatten_running(response.response)); }); }; LuigiAPI.prototype.getResourceList = function(callback) { jsonRPC(this.urlRoot + "/resource_list", {}, function(response) { callback(flatten_running(response.response)); }); }; LuigiAPI.prototype.disableWorker = function(workerId) { jsonRPC(this.urlRoot + "/disable_worker", {'worker': workerId}); }; LuigiAPI.prototype.setWorkerProcesses = function(workerId, n, callback) { var data = {worker: workerId, n: n}; jsonRPC(this.urlRoot + "/set_worker_processes", data, function(response) { callback(); }); }; LuigiAPI.prototype.sendSchedulerMessage = function(workerId, taskId, content, callback) { var data = {worker: workerId, task: taskId, content: content}; jsonRPC(this.urlRoot + "/send_scheduler_message", data, function(response) { if (callback) { callback(response.response.message_id); } }); }; LuigiAPI.prototype.getSchedulerMessageResponse = function(taskId, messageId, callback) { var data = {task_id: taskId, message_id: messageId}; jsonRPC(this.urlRoot + "/get_scheduler_message_response", data, function(response) { callback(response.response.response); }); }; LuigiAPI.prototype.isPauseEnabled = function(callback) { jsonRPC(this.urlRoot + '/is_pause_enabled', {}, function(response) { callback(response.response.enabled); }); }; LuigiAPI.prototype.hasTaskHistory = function(callback) { jsonRPC(this.urlRoot + 
'/has_task_history', {}, function(response) { callback(response.response); }); }; LuigiAPI.prototype.pause = function() { jsonRPC(this.urlRoot + '/pause'); }; LuigiAPI.prototype.unpause = function() { jsonRPC(this.urlRoot + '/unpause'); }; LuigiAPI.prototype.isPaused = function(callback) { jsonRPC(this.urlRoot + "/is_paused", {}, function(response) { callback(!response.response.paused); }); }; LuigiAPI.prototype.updateResource = function(resource, n, callback) { var data = {'resource': resource, 'amount': n}; jsonRPC(this.urlRoot + "/update_resource", data, function(response) { callback(); }); }; return LuigiAPI; })(); ================================================ FILE: luigi/static/visualiser/js/test/graph_test.js ================================================ module("graph.js"); test("nodeFromTask", function() { var task = { deps: ["B1","C1"], taskId: "A1", status: "DONE", name: "A", params: {}, priority: 0, }; var expected = { taskId: "A1", status: "DONE", trackingUrl: "#A1", deps: ["B1","C1"], depth: -1, name: "A", params: {}, priority: 0, }; let graph = { hashBase: "#" } deepEqual(Graph.testableMethods.nodeFromTask.bind(graph)(task), expected); }); test("uniqueIndexByProperty", function() { var input = [ {a:"x", b:100}, {a:"y", b:101}, {a:"z", b:102} ]; var expected = { "x": 0, "y": 1, "z": 2 }; deepEqual(Graph.testableMethods.uniqueIndexByProperty(input, "a"), expected); }); test("createDependencyEdges", function() { var A = {taskId: "A", deps: ["B","C"]}; var B = {taskId: "B", deps: ["D"]}; var C = {taskId: "C", deps: []}; var D = {taskId: "D", deps: []}; var nodes = [A,B,C,D]; var nodeIndex = {"A":0, "B":1, "C":2, "D":3}; var edges = Graph.testableMethods.createDependencyEdges(nodes, nodeIndex); var expected = [ {source: A, target: B}, {source: A, target: C}, {source: B, target: D} ]; deepEqual(edges, expected); }); test("computeDepth", function() { var A = {taskId: "A", deps: ["B","C"], depth:-1}; var B = {taskId: "B", deps: ["D"], depth:-1}; var C = 
{taskId: "C", deps: [], depth:-1}; var D = {taskId: "D", deps: [], depth:-1}; var E = {taskId: "C", deps: [], depth:-1}; var nodes = [A,B,C,D,E]; var nodeIndex = {"A":0, "B":1, "C":2, "D":3}; Graph.testableMethods.computeDepth(nodes, nodeIndex); equal(A.depth, 0); equal(B.depth, 1); equal(C.depth, 1); equal(D.depth, 2); equal(E.depth, -1); }); test("computeRowsSelfDeps", function () { var A1 = {name: "A", taskId: "A1", deps: ["A2"], depth: -1} var A2 = {name: "A", taskId: "A2", deps: [], depth: -1} var nodes = [A1, A2] var nodeIndex = {"A1": 0, "A2": 1} var rowSizes = Graph.testableMethods.computeRows(nodes, nodeIndex) equal(A1.depth, 0) equal(A2.depth, 1) deepEqual(rowSizes, [1, 1]) }); test("computeRowsGrouped", function() { var A0 = {name: "A", taskId: "A0", deps: ["D0", "B0"], depth: -1} var B0 = {name: "B", taskId: "B0", deps: ["C1", "C2"], depth: -1} var C1 = {name: "C", taskId: "C1", deps: ["D1", "E1"], depth: -1} var C2 = {name: "C", taskId: "C2", deps: ["D2", "E2"], depth: -1} var D0 = {name: "D", taskId: "D0", deps: [], depth: -1} var D1 = {name: "D", taskId: "D1", deps: [], depth: -1} var D2 = {name: "D", taskId: "D2", deps: [], depth: -1} var E1 = {name: "E", taskId: "E1", deps: [], depth: -1} var E2 = {name: "E", taskId: "E2", deps: [], depth: -1} var nodes = [A0, B0, C1, C2, D0, D1, D2, E1, E2] var nodeIndex = {"A0": 0, "B0": 1, "C1": 2, "C2": 3, "D0": 4, "D1": 5, "D2": 6, "E1": 7, "E2": 8} var rowSizes = Graph.testableMethods.computeRows(nodes, nodeIndex) equal(A0.depth, 0) equal(B0.depth, 1) equal(C1.depth, 2) equal(C2.depth, 2) equal(D0.depth, 3) equal(D1.depth, 3) equal(D2.depth, 3) equal(E1.depth, 4) equal(E2.depth, 4) deepEqual(rowSizes, [1, 1, 2, 3, 2]) }); test("createGraph", function() { var tasks = [ {taskId: "A", deps: ["B","C"], status: "PENDING"}, {taskId: "B", deps: ["D"], status: "RUNNING"}, {taskId: "C", deps: [], status: "DONE"}, {taskId: "D", deps: [], status: "DONE"}, {taskId: "E", deps: [], status: "DONE"} ]; var graph = 
Graph.testableMethods.createGraph(tasks); equal(graph.nodes.length, 4); equal(graph.links.length, 3); $.each(graph.nodes, function() { notEqual(this.x, 0); notEqual(this.y, 0); }); // TODO: more assertions }); ================================================ FILE: luigi/static/visualiser/js/tipsy.js ================================================ // tipsy, facebook style tooltips for jquery // version 1.0.0a // (c) 2008-2010 jason frame [jason@onehackoranother.com] // released under the MIT license (function($) { function maybeCall(thing, ctx) { return (typeof thing == 'function') ? (thing.call(ctx)) : thing; } function Tipsy(element, options) { this.$element = $(element); this.options = options; this.enabled = true; this.fixTitle(); } Tipsy.prototype = { show: function() { var title = this.getTitle(); if (title && this.enabled) { var $tip = this.tip(); $tip.find('.tipsy-inner')[this.options.html ? 'html' : 'text'](title); $tip[0].className = 'tipsy'; // reset classname in case of dynamic gravity $tip.remove().css({top: 0, left: 0, visibility: 'hidden', display: 'block'}).prependTo(document.body); var pos = $.extend({}, this.$element.offset(), { width: this.$element[0].offsetWidth || 0, height: this.$element[0].offsetHeight || 0 }); if (typeof this.$element[0].nearestViewportElement == 'object') { // SVG var el = this.$element[0]; var rect = el.getBoundingClientRect(); pos.width = rect.width; pos.height = rect.height; } var actualWidth = $tip[0].offsetWidth, actualHeight = $tip[0].offsetHeight, gravity = maybeCall(this.options.gravity, this.$element[0]); var tp; switch (gravity.charAt(0)) { case 'n': tp = {top: pos.top + pos.height + this.options.offset, left: pos.left + pos.width / 2 - actualWidth / 2}; break; case 's': tp = {top: pos.top - actualHeight - this.options.offset, left: pos.left + pos.width / 2 - actualWidth / 2}; break; case 'e': tp = {top: pos.top + pos.height / 2 - actualHeight / 2, left: pos.left - actualWidth - this.options.offset}; break; case 
'w': tp = {top: pos.top + pos.height / 2 - actualHeight / 2, left: pos.left + pos.width + this.options.offset}; break; } if (gravity.length == 2) { if (gravity.charAt(1) == 'w') { tp.left = pos.left + pos.width / 2 - 15; } else { tp.left = pos.left + pos.width / 2 - actualWidth + 15; } } $tip.css(tp).addClass('tipsy-' + gravity); $tip.find('.tipsy-arrow')[0].className = 'tipsy-arrow tipsy-arrow-' + gravity.charAt(0); if (this.options.className) { $tip.addClass(maybeCall(this.options.className, this.$element[0])); } if (this.options.fade) { $tip.stop().css({opacity: 0, display: 'block', visibility: 'visible'}).animate({opacity: this.options.opacity}); } else { $tip.css({visibility: 'visible', opacity: this.options.opacity}); } var t = this; var set_hovered = function(set_hover){ return function(){ t.$tip.stop(); t.tipHovered = set_hover; if (!set_hover){ if (t.options.delayOut === 0) { t.hide(); } else { setTimeout(function() { if (t.hoverState == 'out') t.hide(); }, t.options.delayOut); } } }; }; $tip.hover(set_hovered(true), set_hovered(false)); } }, hide: function() { if (this.options.fade) { this.tip().stop().fadeOut(function() { $(this).remove(); }); } else { this.tip().remove(); } }, fixTitle: function() { var $e = this.$element; if ($e.attr('title') || typeof($e.attr('original-title')) != 'string') { $e.attr('original-title', $e.attr('title') || '').removeAttr('title'); } if (typeof $e.context.nearestViewportElement == 'object'){ if ($e.children('title').length){ $e.append('' + ($e.children('title').text() || '') + '') .children('title').remove(); } } }, getTitle: function() { var title, $e = this.$element, o = this.options; this.fixTitle(); if (typeof o.title == 'string') { var title_name = o.title == 'title' ? 
'original-title' : o.title; if ($e.children(title_name).length){ title = $e.children(title_name).html(); } else{ title = $e.attr(title_name); } } else if (typeof o.title == 'function') { title = o.title.call($e[0]); } title = ('' + title).replace(/(^\s*|\s*$)/, ""); return title || o.fallback; }, tip: function() { if (!this.$tip) { this.$tip = $('
    ').html('
    '); } return this.$tip; }, validate: function() { if (!this.$element[0].parentNode) { this.hide(); this.$element = null; this.options = null; } }, enable: function() { this.enabled = true; }, disable: function() { this.enabled = false; }, toggleEnabled: function() { this.enabled = !this.enabled; } }; $.fn.tipsy = function(options) { if (options === true) { return this.data('tipsy'); } else if (typeof options == 'string') { var tipsy = this.data('tipsy'); if (tipsy) tipsy[options](); return this; } options = $.extend({}, $.fn.tipsy.defaults, options); if (options.hoverlock && options.delayOut === 0) { options.delayOut = 100; } function get(ele) { var tipsy = $.data(ele, 'tipsy'); if (!tipsy) { tipsy = new Tipsy(ele, $.fn.tipsy.elementOptions(ele, options)); $.data(ele, 'tipsy', tipsy); } return tipsy; } function enter() { var tipsy = get(this); tipsy.hoverState = 'in'; if (options.delayIn === 0) { tipsy.show(); } else { tipsy.fixTitle(); setTimeout(function() { if (tipsy.hoverState == 'in') tipsy.show(); }, options.delayIn); } } function leave() { var tipsy = get(this); tipsy.hoverState = 'out'; if (options.delayOut === 0) { tipsy.hide(); } else { var to = function() { if (!tipsy.tipHovered || !options.hoverlock){ if (tipsy.hoverState == 'out') tipsy.hide(); } }; setTimeout(to, options.delayOut); } } if (options.trigger != 'manual') { var binder = options.live ? 'live' : 'bind', eventIn = options.trigger == 'hover' ? 'mouseenter' : 'focus', eventOut = options.trigger == 'hover' ? 'mouseleave' : 'blur'; this[binder](eventIn, enter)[binder](eventOut, leave); } return this; }; $.fn.tipsy.defaults = { className: null, delayIn: 0, delayOut: 0, fade: false, fallback: '', gravity: 'n', html: false, live: false, offset: 0, opacity: 0.8, title: 'title', trigger: 'hover', hoverlock: false }; // Overwrite this method to provide options on a per-element basis. 
// For example, you could store the gravity in a 'tipsy-gravity' attribute: // return $.extend({}, options, {gravity: $(ele).attr('tipsy-gravity') || 'n' }); // (remember - do not modify 'options' in place!) $.fn.tipsy.elementOptions = function(ele, options) { return $.metadata ? $.extend({}, options, $(ele).metadata()) : options; }; $.fn.tipsy.autoNS = function() { return $(this).offset().top > ($(document).scrollTop() + $(window).height() / 2) ? 's' : 'n'; }; $.fn.tipsy.autoWE = function() { return $(this).offset().left > ($(document).scrollLeft() + $(window).width() / 2) ? 'e' : 'w'; }; /** * yields a closure of the supplied parameters, producing a function that takes * no arguments and is suitable for use as an autogravity function like so: * * @param margin (int) - distance from the viewable region edge that an * element should be before setting its tooltip's gravity to be away * from that edge. * @param prefer (string, e.g. 'n', 'sw', 'w') - the direction to prefer * if there are no viewable region edges effecting the tooltip's * gravity. It will try to vary from this minimally, for example, * if 'sw' is preferred and an element is near the right viewable * region edge, but not the top edge, it will set the gravity for * that element's tooltip to be 'se', preserving the southern * component. */ $.fn.tipsy.autoBounds = function(margin, prefer) { return function() { var dir = {ns: prefer[0], ew: (prefer.length > 1 ? prefer[1] : false)}, boundTop = $(document).scrollTop() + margin, boundLeft = $(document).scrollLeft() + margin, $this = $(this); if ($this.offset().top < boundTop) dir.ns = 'n'; if ($this.offset().left < boundLeft) dir.ew = 'w'; if ($(window).width() + $(document).scrollLeft() - $this.offset().left < margin) dir.ew = 'e'; if ($(window).height() + $(document).scrollTop() - $this.offset().top < margin) dir.ns = 's'; return dir.ns + (dir.ew ? 
dir.ew : ''); }; }; })(jQuery); ================================================ FILE: luigi/static/visualiser/js/util.js ================================================ function escapeHtml(unsafe) { return unsafe .replace(/&/g, "&") .replace(//g, ">") .replace(/"/g, """) .replace(/'/g, "'"); } ================================================ FILE: luigi/static/visualiser/js/visualiserApp.js ================================================ function visualiserApp(luigi) { var templates = {}; var typingTimer = 0; var dt; // DataTable instantiated in $(document).ready() var missingCategories = {}; var currentFilter = { taskFamily: "", taskCategory: [], tableFilter: "" }; var taskIcons = { PENDING: 'pause', RUNNING: 'play', BATCH_RUNNING: 'play', DONE: 'check', FAILED: 'times', UPSTREAM_FAILED: 'warning', DISABLED: 'minus-circle', UPSTREAM_DISABLED: 'warning' }; var VISTYPE_DEFAULT = 'svg'; /* * Updates view of the Visualization type. */ function updateVisType(newVisType) { $('#toggleVisButtons label').removeClass('active'); var visTypeInput = $('#toggleVisButtons input[value="' + newVisType + '"]'); visTypeInput.parent().addClass('active'); visTypeInput.prop('checked', true); } function loadTemplates() { $("script[type='text/template']").each(function(i, element) { var name = $(element).attr("name"); var content = $(element).text(); templates[name] = content; }); } function renderTemplate(templateName, dataObject) { return $("
    ").html(Mustache.render(templates[templateName], dataObject)); } function formatTime(dateObject) { return dateObject.getHours() + ":" + dateObject.getMinutes() + ":" + dateObject.getSeconds(); } function taskToDisplayTask(task) { var taskName = task.name; var taskParams = JSON.stringify(task.params); var displayTime = new Date(Math.floor(task.last_updated*1000)).toLocaleString(); var time_running = -1; if (task.status == "RUNNING" && "time_running" in task) { var current_time = new Date().getTime(); var minutes_running = Math.round((current_time - task.time_running * 1000) / 1000 / 60); time_running = task.time_running; displayTime += " | " + minutes_running + " minutes"; } return { taskId: task.taskId, encodedTaskId: encodeURIComponent(task.taskId), taskName: taskName, taskParams: taskParams, displayName: task.display_name, priority: task.priority, resources: JSON.stringify(task.resources_running || task.resources).replace(/,"/g, ', "'), displayTime: displayTime, displayTimestamp: task.last_updated, timeRunning: time_running, trackingUrl: task.tracking_url, status: task.status, graph: (task.status == "PENDING" || task.status == "RUNNING" || task.status == "DONE"), error: task.status == "FAILED", re_enable: task.status == "DISABLED" && task.re_enable_able, mark_as_done: (task.status == "RUNNING" || task.status == "FAILED" || task.status == "DISABLED"), statusMessage: task.status_message, progressPercentage: task.progress_percentage, acceptsMessages: task.accepts_messages, workerIdRunning: task.worker_running, }; } function taskCategoryIcon(category) { var iconClass; var iconColor; switch (category) { case 'PENDING': iconClass = 'fa-pause'; iconColor = 'yellow'; break; case 'RUNNING': iconClass = 'fa-play'; iconColor = 'aqua'; break; case 'BATCH_RUNNING': iconClass = 'fa-play'; iconColor = 'purple'; break; case 'DONE': iconClass = 'fa-check'; iconColor = 'green'; break; case 'FAILED': iconClass = 'fa-times'; iconColor = 'red'; break; case 'DISABLED': iconClass = 
'fa-minus-circle'; iconColor = 'gray'; break; case 'UPSTREAM_FAILED': iconClass = 'fa-warning'; iconColor = 'maroon'; break; case 'UPSTREAM_DISABLED': iconClass = 'fa-warning'; iconColor = 'gray'; break; default: iconClass = 'fa-bug'; iconColor = 'orange'; break; } return ''; } /** * Filter table by all activated info boxes. */ function filterByCategory(dt, activeBoxes) { if (activeBoxes === undefined) { activeBoxes = getActiveBoxes(); } currentFilter.taskCategory = activeBoxes; dt.column(0).search(categoryQuery(activeBoxes), regex=true).draw(); } function categoryQuery(activeBoxes) { // Searched content will be . return '\\b(' + activeBoxes.join('|') + ')\\b'; } function getActiveBoxes() { var infoBoxes = $('.info-box'); var activeBoxes = []; infoBoxes.each(function (i) { if (infoBoxes[i].dataset.on === 'yes') { activeBoxes.push(infoBoxes[i].dataset.category); } }); return activeBoxes; } function filterByTaskFamily(taskFamily, dt) { currentFilter.taskFamily = taskFamily; if (taskFamily === "") { dt.column(1).search('').draw(); } else { dt.column(1).search('^' + taskFamily + '$', regex = true).draw(); } } function toggleInfoBox(infoBox, activate) { var infoBoxColor = infoBox.dataset.color; var infoBoxIcon = $(infoBox).find('.info-box-icon'); var colorClass = 'bg-' + infoBoxColor; if ((infoBox.dataset.on === undefined) || (infoBox.dataset.on === 'no') || activate) { infoBox.dataset.on = 'yes'; infoBoxIcon.removeClass(colorClass); $(infoBox).addClass(colorClass); } else { infoBox.dataset.on = 'no'; $(infoBox).removeClass(colorClass); infoBoxIcon.addClass(colorClass); } } function renderSidebar(tasks) { // tasks is a list of task names var counts = {}; $.each(tasks, function(i) { var name = tasks[i]; if (counts[name] === undefined) { counts[name] = 0; } counts[name] += 1; }); var taskList = []; $.each(counts, function (name) { var dotIndex = name.indexOf('.'); var prefix = 'Others'; if (dotIndex > 0) { prefix = name.slice(0, dotIndex); } var prefixList = 
taskList.find(function (pref) { return pref.name == prefix; }) if (prefixList) { prefixList.tasks.push({name: name, count: counts[name]}); } else { prefixList = { name: prefix, tasks: [{name: name, count: counts[name]}] } taskList.push(prefixList); } }); taskList.sort(function(a,b){ if (a.name == 'Others') { if (b.name == 'Others') { return 0; } return 1; } else if (b.name == 'Others') { return -1; } return a.name.localeCompare(b.name); }); taskList.forEach(function(p){ p.tasks.sort(function(a,b){ return a.name.localeCompare(b.name); }); }); return renderTemplate("sidebarTemplate", {"tasks": taskList}); } function selectSidebarItem(item) { var sidebarItems = $('.sidebar').find('li'); sidebarItems.each(function (i) { var item2 = sidebarItems[i]; if (item2.dataset.task === undefined) { return; } if (item === item2) { if ($(item2).hasClass('active')) { // item is active, deselect $(item2).removeClass('active'); $(item2).find('.badge').removeClass('bg-green'); } else { // select item $(item2).addClass('active'); $(item2).find('.badge').addClass('bg-green'); } } else { // clear any selection $(item2).removeClass('active'); $(item2).find('.badge').removeClass('bg-green'); } }); } function renderWarnings() { return renderTemplate( "warningsTemplate", {missingCategories: $.map(missingCategories, function (v, k) {return v;})} ); } function processWorker(worker) { worker.encoded_first_task = encodeURIComponent(worker.first_task); worker.tasks = worker.running.map(taskToDisplayTask); worker.tasks.sort(function(task1, task2) { return task1.timeRunning - task2.timeRunning; }); worker.start_time = new Date(worker.started * 1000).toLocaleString(); worker.active = new Date(worker.last_active * 1000).toLocaleString(); worker.is_disabled = worker.state === 'disabled'; return worker; } function renderWorkers(workers) { return renderTemplate("workerTemplate", {"workerList": workers.map(processWorker)}); } function processResource(resource) { resource.tasks = 
resource.running.map(taskToDisplayTask); resource.percent_used = 100 * resource.num_used / resource.num_total; if (resource.percent_used >= 100) { resource.bar_type = 'danger'; resource.percent_used = 100; } else if (resource.percent_used > 50) { resource.bar_type = 'warning'; } else { resource.bar_type = 'success'; } return resource; } function renderResources(resources) { return renderTemplate("resourceTemplate", { "resources": resources.map(processResource).sort(function(r1, r2) { if (r1.percent_used > r2.percent_used) return -1; else if (r1.percent_used < r2.percent_used) return 1; else if (r1.num_used > r2.num_used) return -1; else if (r1.num_used < r2.num_used) return 1; else if (r1.name < r2.name) return -1; else if (r1.name > r2.name) return 1; else return 0; }) }); } function switchTab(tabId) { $(".tabButton").parent().removeClass("active"); $(".tab-pane").removeClass("active"); $("#" + tabId).addClass("active"); $(".navbar-nav li").removeClass("active"); $(".js-nav-link[data-tab=" + tabId + "]").parent().addClass("active"); updateSidebar(tabId); } function showErrorTrace(data) { data.error = decodeError(data.error); if (data.taskParams) { data.taskParams = Object.entries(data.taskParams).map(([k,v]) => `--${k.replace(/_/g, '-')} ${JSON.stringify(v)}`).join(" "); } $("#errorModal").empty().append(renderTemplate("errorTemplate", data)); $("#errorModal").modal({}); } function showStatusMessage(data) { $("#statusMessageModal").empty().append(renderTemplate("statusMessageTemplate", data)); $("#statusMessageModal").modal({}); var refreshInterval = setInterval(function() { if ($("#statusMessageModal").is(":hidden")) clearInterval(refreshInterval); else { luigi.getTaskStatusMessage(data.taskId, function(data) { if (data.statusMessage === null) $("#statusMessageModal pre").hide(); else { $("#statusMessageModal pre").html(data.statusMessage).show(); } }); luigi.getTaskProgressPercentage(data.taskId, function(data) { // show or hide the progress bar container in the 
message modal $("#statusMessageModal .progress").toggle(data.progressPercentage !== null); // adjust the status of both progress bars (message modal and worker list) var value = data.progressPercentage || 0; var progressBars = $('#statusMessageModal .progress-bar, ' + '.worker-table tbody .taskProgressBar[data-task-id="' + data.taskId + '"]'); progressBars.attr('aria-valuenow', value) .text(value + '%') .css({'width': value + '%'}); }); } }, 500 ); } function showSchedulerMessageModal(data) { var $modal = $("#schedulerMessageModal"); $modal.empty().append(renderTemplate("schedulerMessageTemplate", data)); var $input = $modal.find("#schedulerMessageInput"); var $send = $modal.find("#schedulerMessageButton"); var $awaitResponse = $modal.find("#schedulerMessageAwaitResponse"); var $responseContainer = $modal.find("#schedulerMessageResponse"); var $responseSpinner = $responseContainer.find("pre > i"); var $responseContent = $responseContainer.find("pre > div"); $input.on("keypress", function($event) { if (event.keyCode == 13) { $send.trigger("click"); $event.preventDefault(); } }); $send.on("click", function($event) { var content = $input.val(); var awaitResponse = $awaitResponse.prop("checked"); if (content && data.worker) { if (awaitResponse) { $responseContainer.show(); $responseSpinner.show(); $responseContent.empty(); luigi.sendSchedulerMessage(data.worker, data.taskId, content, function(messageId) { var interval = window.setInterval(function() { luigi.getSchedulerMessageResponse(data.taskId, messageId, function(response) { if (response != null) { clearInterval(interval); $responseSpinner.hide(); $responseContent.html(response); } }); }, 1000); }); $event.stopPropagation(); } else { $responseContainer.hide(); luigi.sendSchedulerMessage(data.worker, data.taskId, content); } } }); $modal.on("shown.bs.modal", function() { $input.focus(); }); $modal.modal({}); } function preProcessGraph(dependencyGraph) { var extraNodes = []; var seen = {}; $.each(dependencyGraph, 
function(i, node) { seen[node.taskId] = true; }); $.each(dependencyGraph, function(i, node) { $.each(node.deps, function(j, dep) { if (!seen[dep]) { seen[dep] = true; var paramsStrs = (/\((.*)\)/.exec(dep) || ['', ''])[1].split(', '); var params = {}; $.each(paramsStrs, function(i, param) { if (param !== "") { var kv = param.split('='); params[kv[0]] = kv[1]; } }); extraNodes.push({ name: (/(\w+)\(/.exec(dep) || [])[1], taskId: dep, deps: [], params: params, status: "TRUNCATED" }); } }); }); return dependencyGraph.concat(extraNodes); } function makeGraphCallback(visType, taskId, paint) { function depGraphCallbackD3(dependencyGraph) { $("#searchError").empty(); $("#searchError").removeClass(); if(dependencyGraph.length > 0) { $("#dependencyTitle").text(dependencyGraph[0].display_name); if(dependencyGraph != '{}'){ for (var id in dependencyGraph) { if (dependencyGraph[id].deps.length > 0) { //console.log(asingInput(dependencyGraph, id)); dependencyGraph[id].inputQueue = asingInput(dependencyGraph, id); dependencyGraph[id].inputThroughput = 50; dependencyGraph[id].count = 5; dependencyGraph[id].consumers = 1; }else{ dependencyGraph[id].inputThroughput = 50; dependencyGraph[id].count = 5; dependencyGraph[id].consumers = 1; } } } } else { $("#searchError").addClass("alert alert-error"); $("#searchError").text("Couldn't find task " + taskId); } drawGraphETL(dependencyGraph, paint); bindGraphEvents(); } function depGraphCallback (dependencyGraph) { $("#graphPlaceholder svg").empty(); $("#searchError").empty(); $("#searchError").removeClass(); if(dependencyGraph.length > 0) { $("#dependencyTitle").text(dependencyGraph[0].display_name); var hashBaseObj = URI.parseQuery(location.hash.replace('#', '')); delete hashBaseObj.taskId; var hashBase = '#' + URI.buildQuery(hashBaseObj) + '&taskId='; $("#graphPlaceholder").get(0).graph.updateData(dependencyGraph, hashBase); $("#graphContainer").show(); bindGraphEvents(); } else { $("#searchError").addClass("alert alert-error"); 
$("#searchError").text("Couldn't find task " + taskId); } } function processedCallback(callback) { function processed(dependencyGraph) { return callback(preProcessGraph(dependencyGraph)); } return processed; } if (visType == 'd3') { return processedCallback(depGraphCallbackD3); } else { return processedCallback(depGraphCallback); } } function processHashChange(paint) { var hash = decodeURIComponent(location.hash); // Convert fragment params to object. var fragmentQuery = URI.parseQuery(location.hash.replace('#', '')); // "http://example.org/#!/foo/bar/baz.html"); if (fragmentQuery.tab == "workers") { switchTab("workerList"); } else if (fragmentQuery.tab == "resources") { expandResources(fragmentQuery.resources); switchTab("resourceList"); } else if (fragmentQuery.tab == "graph") { var taskId = fragmentQuery.taskId; var hideDone = fragmentQuery.hideDone === '1' ? true : false; // Populate fields with values from hash. $('#hideDoneCheckbox').prop('checked', hideDone); $("#invertCheckbox").prop('checked', fragmentQuery.invert === '1' ? true : false); $("#js-task-id").val(fragmentQuery.taskId); // Empty errors. $("#searchError").empty(); $("#searchError").removeClass(); var visType = fragmentQuery.visType || VISTYPE_DEFAULT; if (taskId) { var depGraphCallback = makeGraphCallback(visType, taskId, paint); if (fragmentQuery.invert) { luigi.getInverseDependencyGraph(taskId, depGraphCallback, !hideDone); } else { luigi.getDependencyGraph(taskId, depGraphCallback, !hideDone); } } updateVisType(visType); initVisualisation(visType); switchTab("dependencyGraph"); } else { // Tasks tab. // Populate fields with values from hash. if (fragmentQuery.length) { $('select[name=taskTable_length]').val(fragmentQuery.length); } $("#serverSideCheckbox").prop('checked', fragmentQuery.filterOnServer === '1' ? 
true : false); dt.search(fragmentQuery.search__search); $('#familySidebar li').removeClass('active'); $('#familySidebar li .badge').removeClass('bg-green'); if (fragmentQuery.family) { family_item = $('#familySidebar li[data-task="' + fragmentQuery.family + '"]'); family_item.addClass('active'); family_item.find('.badge').addClass('bg-green'); filterByTaskFamily(fragmentQuery.family, dt); } if (fragmentQuery.statuses) { var statuses = JSON.parse(fragmentQuery.statuses); $.each(statuses, function (status) { toggleInfoBox($('#' + statuses[status] + '_info')[0], true); }); filterByCategory(dt, statuses); } if (fragmentQuery.order) { dt.order([fragmentQuery.order.split(',')]); } dt.draw(); switchTab("taskList"); } } function bindGraphEvents() { var fragmentQuery = URI.parseQuery(location.hash.replace('#', '')); var visType = fragmentQuery.visType; if (visType === 'd3') { $('.node').click(function(event) { var taskDiv = $(this).find('.taskNode'); var taskId = taskDiv.attr("data-task-id"); event.preventDefault(); // NOTE : hasClass() not reliable inside SVG if ($(this).attr('class').match(/\bFAILED\b/)) { luigi.getErrorTrace(taskId, function (error) { showErrorTrace(error); }); } else { fragmentQuery['taskId'] = taskId; window.location.href = 'index.html#' + URI.buildQuery(fragmentQuery); } }); } else { $(".graph-node-a").click(function(event) { var taskId = $(this).attr("data-task-id"); var status = $(this).attr("data-task-status"); if (status == "FAILED") { event.preventDefault(); luigi.getErrorTrace(taskId, function(error) { showErrorTrace(error); }); } }); } } function bindListEvents() { $(window).on('hashchange', processHashChange); $('#serverSideCheckbox').click(function(e) { e.preventDefault(); changeState('filterOnServer', this.checked ? '1' : null); updateTasks(); }); $("#invertCheckbox").click(function(e) { e.preventDefault(); changeState('invert', this.checked ? '1' : null); }); $('#hideDoneCheckbox').click(function(e) { // Copy checkbox value to hash. 
e.preventDefault(); changeState('hideDone', this.checked ? '1' : null); }); $("a[href=#list]").click(function() { location.hash=""; }); $("#loadTaskForm").submit(function(event) { event.preventDefault(); var taskId = $(this).find("input").val(); changeState('taskId', taskId.length > 0 ? taskId : null); }); $('.info-box').on('click', function () { toggleInfoBox(this); filterByCategory(dt); }); $('input[name=vis-type]').on('change', function () { changeState('visType', $(this).val()); }); /* Note: The #filter-input element is used by LuigiAPI to constrain requests to the server. When the accompanying button is pressed we force a reload. */ $('#serverSide').on('change', 'label', function () { updateTasks(); }); } function asingInput(worker, id){ if (worker[id].deps.length > 0) { //console.log(worker[id].deps); return worker[id].deps; } } function getDurations(tasks, listId){ var durations = {}; for (var i = 0; i < listId.length; i++) { for (var j = 0; j < tasks.length; j++) { if (listId[i] === tasks[j].taskId) { // The duration of the task from when it started running to when it finished. 
var finishTime = new Date(tasks[j].last_updated*1000); var startTime = new Date(tasks[j].time_running*1000); durations[listId[i]] = new Date(finishTime - startTime); } } } return durations; } function getParam(tasks, id){ for (var i = 0; i < tasks.length; i++) { if (tasks[i].taskId === id) { return tasks[i].worker_running; } } } function getStatusTasks(tasks){ var status; for (var i = 0; i < tasks.length; i++) { if (tasks[i].status === "DONE") { status = true; } else { return false; } } return status; } function drawGraphETL(tasks, paint){ // Set up zoom support var svg = d3.select("#mysvg"); var inner = svg.select("g"), zoom = d3.behavior.zoom().on("zoom", function() { inner.attr("transform", "translate(" + d3.event.translate + ")" + "scale(" + d3.event.scale + ")"); }); svg.call(zoom); // Create map of taskId to task var taskIdMap = {}; $.each(tasks, function (i, task) { taskIdMap[task.taskId] = task; }); var render = new dagreD3.render(); // Left-to-right layout var g = new dagreD3.graphlib.Graph(); g.setGraph({ nodesep: 70, ranksep: 50, rankdir: "LR", marginx: 20, marginy: 20, height: 400, ranker: "longest-path" }); function draw(isUpdate) { for (var id in tasks) { var task = tasks[id]; var className = task.status; var html = "
    "; html += ""; html += ""+task.name+""; html += ""+ task.status +""; html += "
    "; g.setNode(task.taskId, { labelType: "html", label: html, rx: 5, ry: 5, padding: 0, class: className }); if (task.inputQueue) { for (var i = 0; i < task.inputQueue.length; i++) { // Destination node may not be in tasks if this is an inverted graph if (taskIdMap[task.inputQueue[i]] !== undefined) { if (task.status === "DONE") { var durations = getDurations(tasks, task.inputQueue); var duration = durations[task.inputQueue[i]]; var oneDayInMilliseconds = 24 * 60 * 60 * 1000; var durationLabel; if (duration.getTime() < oneDayInMilliseconds) { // Label task duration in stripped ISO format (hh:mm:ss.f) durationLabel = duration.toISOString().substr(11, 12); } else { durationLabel = "> 24h"; } g.setEdge(task.inputQueue[i], task.taskId, { label: durationLabel, width: 40 }); } else { g.setEdge(task.inputQueue[i], task.taskId, { width: 40 }); } } } } } var styleTooltip = function(name, description) { return "

    " + name + "

    " + description + "

    "; }; inner.call(render, g); if(paint){ // Zoom and scale to fit var zoomScale = zoom.scale(); var graphWidth = g.graph().width + 80; var graphHeight = g.graph().height + 40; var width = parseInt(svg.style("width").replace(/px/, "")); var height = parseInt(svg.style("height").replace(/px/, "")); zoomScale = Math.min(width / graphWidth, height / graphHeight); var translate = [(width/2) - ((graphWidth*zoomScale)/2), (height/2) - ((graphHeight*zoomScale)/2)]; zoom.translate(translate); zoom.scale(zoomScale); zoom.event(isUpdate ? svg.transition().duration(3000) : d3.select("#mysvg")); } inner.selectAll("g.node") .attr("title", function(v) { return styleTooltip(v, getParam(tasks, v)); }) .each(function(v) { $(this).tipsy({ gravity: "w", opacity: 1, html: true }); }); } draw(); } /* DataTables functions */ // Remove tasks of a given category and add new ones. function updateTaskCategory(dt, category, tasks) { var taskMap = {}; var mostImportantCategory = function (cat1, cat2) { var priorities = [ 'RUNNING', 'BATCH_RUNNING', 'DONE', 'PENDING', 'UPSTREAM_DISABLED', 'UPSTREAM_FAILED', 'DISABLED', 'FAILED' ]; // NOTE : -1 indicates not in list var i1 = priorities.indexOf(cat1); var i2 = priorities.indexOf(cat2); var ret; if (i1 > i2) { ret = cat1; } else { ret = cat2; } return ret; }; dt.rows(function (i, data) { taskMap[data.taskId] = data.category; return data.category === category; }).remove(); var taskCount; /* Check for integers in tasks. This indicates max-shown-tasks was exceeded */ if (tasks.length === 1 && typeof(tasks[0]) === 'number') { taskCount = tasks[0] === -1 ? 
'unknown' : tasks[0]; missingCategories[category] = {name: category, count: taskCount}; } else { var displayTasks = tasks.map(taskToDisplayTask); displayTasks = displayTasks.filter(function (obj) { if (obj === null) { return false; } if (category === mostImportantCategory(category, taskMap[obj.taskId])) { obj.category = category; return true; } return false; }); dt.rows.add(displayTasks); taskCount = displayTasks.length; delete missingCategories[category]; } $('#'+category+'_info').find('.info-box-number').html(taskCount); $('#'+category+'_info i.fa').removeClass().addClass('fa fa-'+taskIcons[category]); } function updateCurrentFilter() { var content; currentFilter.tableFilter = dt.search(); if ((currentFilter.tableFilter === "") && ($.isEmptyObject(currentFilter.taskCategory)) && (currentFilter.taskFamily === "")) { content = ''; } else { if (currentFilter.taskCategory !== "") { currentFilter.catNames = $.map(currentFilter.taskCategory, function (x) { return {name: x}; }); } content = renderTemplate('currentFilterTemplate', currentFilter); } $('#currentFilter').html(content); } function initVisualisation(newVisType) { // Prepare graphPlaceholder for D3 code if (newVisType == 'd3') { $('#graphPlaceholder').empty(); $('#graphPlaceholder').html('
    '); } else { $('#graphPlaceholder').empty(); var graph = new Graph.DependencyGraph($("#graphPlaceholder")[0]); $("#graphPlaceholder")[0].graph = graph; } } function updateTasks() { $('.status-info .info-box-number').text('?'); $('.status-info i.fa').removeClass().addClass('fa fa-spinner fa-pulse'); var ajax1 = luigi.getRunningTaskList(function(runningTasks) { updateTaskCategory(dt, 'RUNNING', runningTasks); }); var ajax2 = luigi.getBatchRunningTaskList(function(batchRunningTasks) { updateTaskCategory(dt, 'BATCH_RUNNING', batchRunningTasks); }); var ajax3 = luigi.getFailedTaskList(function(failedTasks) { updateTaskCategory(dt, 'FAILED', failedTasks); }); var ajax4 = luigi.getUpstreamFailedTaskList(function(upstreamFailedTasks) { updateTaskCategory(dt, 'UPSTREAM_FAILED', upstreamFailedTasks); }); var ajax5 = luigi.getDisabledTaskList(function(disabledTasks) { updateTaskCategory(dt, 'DISABLED', disabledTasks); }); var ajax6 = luigi.getUpstreamDisabledTaskList(function(upstreamDisabledTasks) { updateTaskCategory(dt, 'UPSTREAM_DISABLED', upstreamDisabledTasks); }); var ajax7 = luigi.getPendingTaskList(function(pendingTasks) { updateTaskCategory(dt, 'PENDING', pendingTasks); }); var ajax8 = luigi.getDoneTaskList(function(doneTasks) { updateTaskCategory(dt, 'DONE', doneTasks); }); $.when(ajax1, ajax2, ajax3, ajax4, ajax5, ajax6, ajax7, ajax8).done(function () { dt.draw(); $('.sidebar').html(renderSidebar(dt.column(1).data())); var selectedFamily = $('.sidebar-menu').find('li[data-task="' + currentFilter.taskFamily + '"]')[0]; selectSidebarItem(selectedFamily); if (selectedFamily) { var selectedUl = $(selectedFamily).parent(); selectedUl.show(); selectedUl.prev().addClass('expanded'); } else { var others = $('.sidebar-folder:contains(Others)') others.addClass('expanded') others.next().show() } $('.sidebar-menu').on('click', 'li:not(.sidebar-folder)', function (e) { e.stopPropagation(); if (this.dataset.task) { selectSidebarItem(this); if ($(this).hasClass('active')) { 
filterByTaskFamily(this.dataset.task, dt); } else { filterByTaskFamily("", dt); } } }); $('.sidebar-menu').on('click', '.sidebar-folder', function () { const ul = this.nextElementSibling; $(ul).slideToggle() this.classList.toggle('expanded') }) $('#clear-task-filter').on('click', function () { filterByTaskFamily("", dt); }); if ($.isEmptyObject(missingCategories)) { $('#warnings').html(''); } else { $('#warnings').html(renderWarnings()); } processHashChange(); }); } function updateSidebar(tabName) { if (tabName === 'taskList') { $('body').removeClass('sidebar-collapse'); } else { $('body').addClass('sidebar-collapse'); } } // Error strings may or may not be JSON encoded, depending on client version // Decoding an unencoded string may raise an exception. function decodeError(error) { var decoded; try { decoded = JSON.parse(error); } catch (e) { decoded = error; } return decoded; } /** * Return HTML of a task parameter dictionary * @param params: task parameter dictionary */ function renderParams(params) { var htmls = []; for (var key in params) { htmls.push('' + escapeHtml(key) + '=' + escapeHtml(params[key]) + ''); } return htmls.join(', '); } /** * Updates the number of worker processes of a worker * @param worker: the id of the worker * @param n: the number of processes to set */ function updateWorkerProcesses(worker, n) { n = Math.max(1, n); // the spinner is just for visual feedback var $label = $('#workerList').find('#label-n-workers[data-worker="' + worker + '"]'); $label.html(''); luigi.setWorkerProcesses(worker, n, function() { $label.text(n); }); } /** * Updates the number of units of a given resource available in the scheduler * @param resource: the name of the resource * @param n: the number of units to set the resource limit to */ function updateResourceCount(resource, n) { var progressBar = $('#' + resource + '-resource-box .progress-bar'); var used = /(\S+)\//.exec(progressBar.text())[1]; nVal = parseInt(n); if (isNaN(nVal) || nVal < 0) { return; } 
usedVal = parseInt(used); width = Math.floor(100 * usedVal / nVal); if (width < 0) { width = 0; } if (width > 100) { width = 100; } luigi.updateResource(resource, n, function() { progressBar.text(usedVal + '/' + nVal); progressBar.attr('style', 'width: ' + width + '%'); }); } /** * Returns the current units of a resource used * @param resource: the name of the resource */ function currentResourceCount(resource) { var progressBar = $('#' + resource + '-resource-box .progress-bar'); var count = /\/(\S+)/.exec(progressBar.text())[1]; return parseInt(count); } function changeState(key, value) { var fragmentQuery = URI.parseQuery(location.hash.replace('#', '')); if (value) { fragmentQuery[key] = value; } else { delete fragmentQuery[key]; } location.hash = '#' + URI.buildQuery(fragmentQuery); } function expandedResources() { return $('.resource-box.in').toArray().map(function (val) { return val.dataset.resource; }); } function expandResources(resources) { if (resources === undefined) { resources = []; } else { resources = JSON.parse(resources); } $('.resource-box').each(function (i, item) { if (resources.indexOf(item.dataset.resource) === -1) { $(item).collapse('hide'); } else { $(item).collapse('show'); } }); } /** * Create the pause/unpause toggle */ function createPauseToggle(checked) { var check = checked ? 
" checked" : ""; var html = $(''); $('#pause-form').append(html); $('#pause').bootstrapToggle({ on: 'Running', off: 'Paused', onstyle: 'success', offstyle: 'danger' }); $('#pause').change(function() { if (this.checked) { luigi.unpause(); } else { luigi.pause(); } }) } $(document).ready(function() { loadTemplates(); luigi.hasTaskHistory(function(hasTaskHistory) { if (hasTaskHistory) { $('#topNavbar').append(renderTemplate('topNavbarItem', { label: "History", href: "../../history", }).children()[0]); } }); luigi.isPauseEnabled(function(enabled) { if (enabled) { luigi.isPaused(createPauseToggle); } }); luigi.getWorkerList(function(workers) { $("#workerList").append(renderWorkers(workers)); $('.worker-table tbody').on('click', 'td .statusMessage', function() { var data = $(this).data(); showStatusMessage(data); }); $('.worker-table tbody').on('click', 'td .schedulerMessage', function() { var data = $(this).data(); showSchedulerMessageModal(data); }); }); luigi.getResourceList(function(resources) { $("#resourceList").append(renderResources(resources)); expandResources(URI.parseQuery(location.hash.replace('#', '')).resources); $('.resources-collapse').click(function (e) { e.preventDefault(); var collapse_block = $(this.dataset.target); if (collapse_block.hasClass('collapsing')) { return; } var resource = collapse_block.attr('data-resource'); var resourceList = expandedResources(); var resourceIdx = resourceList.indexOf(resource); if (resourceIdx === -1) { resourceList.push(resource); } else { resourceList.splice(resourceIdx, 1); } changeState('resources', resourceList.length > 0 ? JSON.stringify(resourceList) : null); collapse_block.collapse('toggle'); }); }); dt = $('#taskTable').DataTable({ stateSave: true, stateSaveCallback: function(settings, data) { // Save data table state to browser's hash. 
var state = URI.parseQuery(location.hash.replace('#', '')); if (data.search.search) { state.search__search = data.search.search; } else { delete state.search__search; } var family_search = data.columns[1].search.search; if (family_search) { state.family = family_search.substring(1, family_search.length - 1); } else { delete state.family; } if (currentFilter.taskCategory.length > 0) { state.statuses = JSON.stringify(currentFilter.taskCategory); } else { delete state.statuses; } if (data.order && data.order.length) { state.order = '' + data.order[0][0] + ',' + data.order[0][1]; } if (data.length && data.length !== 10) { // Keep in hash only if length is not default. state.length = data.length; } else { delete state.length; } if (state.filterOnServer) { state.filterOnServer = '1'; } location.hash = '#' + URI.buildQuery(state); }, stateLoadCallback: function(settings) { // Restore datatable state from browser's hash. var fragmentQuery = URI.parseQuery(location.hash.replace('#', '')); var order = []; if (fragmentQuery.order) { order = [fragmentQuery.order.split(',')]; } var family_search = {}; if (fragmentQuery.family) { family_search = {'search': '^' + fragmentQuery.family + '$', 'regex': true}; } var status_search = {}; if (fragmentQuery.statuses) { var statuses = JSON.parse(fragmentQuery.statuses); currentFilter.taskCategory = statuses; status_search = {'search': categoryQuery(statuses), 'regex': true}; } // Prepare state for datatable. var o = { order: order, // Table rows order. length: fragmentQuery.length, // Entries on page. start: 0, // Pagination initial page. time: new Date().getTime(), // Current time to help datatable.js to handle asynchronous. columns: [ {visible: true, search: status_search}, {visible: true, search: family_search}, // Name column {visible: true, search: {}}, // Details column {visible: true, search: {}}, // Priority column {visible: true, search: {}}, // Time column {visible: true, search: {}} // Actions column ], // Search input state. 
search: { caseInsensitive: true, search: fragmentQuery.search__search } }; return o; }, dom: 'l<"#serverSide">frtip', language: { search: 'Filter table:' }, columns: [ { data: 'category', render: function (data, type, row) { return taskCategoryIcon(data) + ' ' + data; } }, {data: 'taskName'}, { data: 'taskParams', render: function(data, type, row) { var params = JSON.parse(data); if (row.resources !== '{}') { return '
    ' + renderParams(params) + '
    ' + row.resources + '
    '; } else { return '
    ' + renderParams(params) + '
    '; } } }, {data: 'priority', width: "2em"}, {data: 'displayTime'}, { className: 'details-control', orderable: false, data: null, render: function (data, type, row) { return Mustache.render(templates.actionsTemplate, row); } } ] }); dt.on('draw', updateCurrentFilter); $('#serverSide').html('
    '); // If using server-side filter we need to updateTasks every time the filter changes $('#taskTable_filter').on('keyup paste', 'input', function () { if ($('#serverSideCheckbox')[0].checked) { clearTimeout(typingTimer); if ($(this).val) { typingTimer = setTimeout(updateTasks, 400); } } }); processHashChange(); updateTasks(); bindListEvents(); $('#taskTable tbody').on('click', 'td.details-control .showError', function () { var tr = $(this).closest('tr'); var row = dt.row( tr ); var data = row.data(); luigi.getErrorTrace(data.taskId, function(error) { showErrorTrace(error); }); } ); $('#taskTable tbody').on('click', 'td.details-control .forgiveFailures', function (ev) { var that = $(this); var tr = that.closest('tr'); var row = dt.row( tr ); var data = row.data(); luigi.forgiveFailures(data.taskId, function(data) { if (ev.altKey) { updateTasks(); // update may not be cheap } else { that.tooltip('hide'); that.remove(); } }); } ); $('#taskTable tbody').on('click', 'td.details-control .markAsDone', function (ev) { var that = $(this); var tr = that.closest('tr'); var row = dt.row( tr ); var data = row.data(); luigi.markAsDone(data.taskId, function(data) { if (ev.altKey) { updateTasks(); // update may not be cheap } else { that.tooltip('hide'); that.remove(); } }); } ); $('#taskTable tbody').on('click', 'td.details-control .re-enable-button', function (ev) { var that = $(this); luigi.reEnable(that.attr("data-task-id"), function(data) { if (ev.altKey) { updateTasks(); // update may not be cheap } else { that.tooltip('hide'); that.remove(); } }); }); $('#taskTable tbody').on('click', 'td.details-control .statusMessage', function () { var data = $(this).data(); showStatusMessage(data); }); $('#taskTable tbody').on('click', 'td.details-control .schedulerMessage', function () { var data = $(this).data(); showSchedulerMessageModal(data); }); $('.navbar-nav').on('click', 'a', function () { var tabName = $(this).data('tab'); updateSidebar(tabName); }); 
$('#workerList').on('show.bs.modal', '#disableWorkerModal', function (event) {
    // Remember which button opened the "disable worker" modal so that the
    // confirm-button handler below knows which worker to act on.
    $('#disableWorkerButton').data('trigger', $(event.relatedTarget));
});

$('#workerList').on('click', '#disableWorkerButton', function () {
    var triggerButton = $(this).data('trigger');
    var worker = triggerButton.data('worker');
    luigi.disableWorker(worker);
    // Show the worker as disabled in the visualiser immediately (no completion
    // callback is passed to luigi.disableWorker) ...
    var box = triggerButton.parents('.box').addClass('box-solid box-default');
    // ... and remove the worker's tool buttons.
    box.find('.box-tools').remove();
});

/**
 * Call updateWorkerProcesses for `worker` with the process count currently
 * displayed in its label shifted by `delta`. Does nothing when the label text
 * does not parse as an integer.
 * @param worker: worker id, as stored in the buttons' data-worker attribute
 * @param delta: amount to add to the displayed count (e.g. +1 or -1)
 */
function adjustWorkerProcesses(worker, delta) {
    var $label = $('#workerList').find('#label-n-workers[data-worker="' + worker + '"]');
    var n = parseInt($label.text(), 10);
    if (!isNaN(n)) {
        updateWorkerProcesses(worker, n + delta);
    }
}

$('#workerList').on('click', '#btn-increment-workers', function ($event) {
    adjustWorkerProcesses($(this).data('worker'), 1);
    $event.preventDefault();
});

$('#workerList').on('click', '#btn-decrement-workers', function ($event) {
    adjustWorkerProcesses($(this).data('worker'), -1);
    $event.preventDefault();
});

$('#workerList').on('show.bs.modal', '#setWorkersModal', function ($event) {
    // Carry the target worker over from the button that opened the modal.
    $('#setWorkersButton').data('worker', $($event.relatedTarget).data('worker'));
    var $input = $(this).find('#setWorkersInput').on('keypress', function ($keyEvent) {
        // Enter submits the dialog. Read the key from this handler's own event
        // object: the global `window.event` is non-standard and is not defined
        // in every browser, where referencing it throws a ReferenceError.
        if ($keyEvent.keyCode === 13) {
            $('#workerList').find('#setWorkersButton').trigger('click');
        }
        // Keep key presses typed into the input from reaching ancestor handlers.
        $keyEvent.stopPropagation();
    });
    // Focus the input after a delay, presumably once the modal's show
    // transition has finished.
    setTimeout(function () {
        $input.focus();
    }, 600);
});

// The keypress handler above is bound on every show, so unbind it (and clear
// the input) when the modal is hidden to avoid stacking duplicate handlers.
$('#workerList').on('hidden.bs.modal', '#setWorkersModal', function () {
    $(this).find('#setWorkersInput').off('keypress').val('');
});

$('#workerList').on('click', '#setWorkersButton', function ($event) {
    var worker = $(this).data('worker');
    var n = parseInt($('#setWorkersInput').val(), 10);
    if (!isNaN(n)) {
        updateWorkerProcesses(worker, n);
    }
    $event.preventDefault();
});
$('#resourceList').on('click', '.btn-increment-resources', function($event) { $event.preventDefault(); var resource = $(this).data('resource'); var count = currentResourceCount(resource); updateResourceCount(resource, count + 1); }); $('#resourceList').on('click', '.btn-decrement-resources', function($event) { $event.preventDefault(); var resource = $(this).data('resource'); var count = currentResourceCount(resource); updateResourceCount(resource, count - 1); }); $('#resourceList').on('show.bs.modal', '#setResourcesModal', function($event) { $('#setResourcesButton').data('resource', $($event.relatedTarget).data('resource')); var $input = $(this).find('#setResourcesInput').on('keypress', function($event) { if (event.keyCode == 13) { $('#resourceList').find('#setResourcesButton').trigger('click'); } $event.stopPropagation(); }); setTimeout(function() { $input.focus(); }.bind(this), 600); }); $('#resourceList').on('hidden.bs.modal', '#setResourcesModal', function() { $(this).find('#setResourcesInput').off('keypress').val(''); }); $('#resourceList').on('click', '#setResourcesButton', function($event) { var resource = $(this).data('resource'); var n = parseInt($("#setResourcesInput").val()); updateResourceCount(resource, n); $event.preventDefault(); }); $('.js-nav-link').click(function(e) { // User followed tab from navigation link. Copy state from fields to hash. e.preventDefault(); var state = {}; var tabId = $(this).attr('data-tab'); if (tabId == 'taskList') { var order = dt.order(); var search = dt.search(); state.tab = 'tasks'; if ($('select[name=taskTable_length]').val() !== '10') { // Add length to hash only if the value is not default. 
state.length = $('select[name=taskTable_length]').val(); } if ($('#serverSideCheckbox').is(':checked')) { state.filterOnServer = '1'; } var family = $('#familySidebar li.active').attr('data-task'); if (family) { state.family = family; } else { delete state.family; } if (currentFilter.taskCategory.length > 0) { state.statuses = JSON.stringify(currentFilter.taskCategory); } else { delete state.statuses; } if (search) { state.search__search = search; } if (order.length > 0) { state.order = '' + order[0][0] + ',' + order[0][1]; } } else if (tabId == 'dependencyGraph') { state.tab = 'graph'; // Get state from fields. if ($('#hideDoneCheckbox').is(':checked')) { state.hideDone = '1'; } if ($('#idTaskForm input.search-query').val()) { state.taskId = $('#idTaskForm input.search-query').val(); } if ($('#invertCheckbox').is(':checked')) { state.invert = '1'; } state.visType = $('input[name=vis-type]:checked').val(); } else if (tabId == 'workerList') { state.tab = 'workers'; } else if (tabId == 'resourceList') { state.resources = JSON.stringify(expandedResources()); state.tab = 'resources'; } location.hash = '#' + URI.buildQuery(state); }); processHashChange(); }); } ================================================ FILE: luigi/static/visualiser/lib/URI/1.18.2/URI.js ================================================ /*! * URI.js - Mutating URLs * * Version: 1.18.2 * * Author: Rodney Rehm * Web: http://medialize.github.io/URI.js/ * * Licensed under * MIT License http://www.opensource.org/licenses/mit-license * */ (function (root, factory) { 'use strict'; // https://github.com/umdjs/umd/blob/master/returnExports.js if (typeof exports === 'object') { // Node module.exports = factory(require('./punycode'), require('./IPv6'), require('./SecondLevelDomains')); } else if (typeof define === 'function' && define.amd) { // AMD. Register as an anonymous module. 
define(['./punycode', './IPv6', './SecondLevelDomains'], factory); } else { // Browser globals (root is window) root.URI = factory(root.punycode, root.IPv6, root.SecondLevelDomains, root); } }(this, function (punycode, IPv6, SLD, root) { 'use strict'; /*global location, escape, unescape */ // FIXME: v2.0.0 renamce non-camelCase properties to uppercase /*jshint camelcase: false */ // save current URI variable, if any var _URI = root && root.URI; function URI(url, base) { var _urlSupplied = arguments.length >= 1; var _baseSupplied = arguments.length >= 2; // Allow instantiation without the 'new' keyword if (!(this instanceof URI)) { if (_urlSupplied) { if (_baseSupplied) { return new URI(url, base); } return new URI(url); } return new URI(); } if (url === undefined) { if (_urlSupplied) { throw new TypeError('undefined is not a valid argument for URI'); } if (typeof location !== 'undefined') { url = location.href + ''; } else { url = ''; } } this.href(url); // resolve to base according to http://dvcs.w3.org/hg/url/raw-file/tip/Overview.html#constructor if (base !== undefined) { return this.absoluteTo(base); } return this; } URI.version = '1.18.2'; var p = URI.prototype; var hasOwn = Object.prototype.hasOwnProperty; function escapeRegEx(string) { // https://github.com/medialize/URI.js/commit/85ac21783c11f8ccab06106dba9735a31a86924d#commitcomment-821963 return string.replace(/([.*+?^=!:${}()|[\]\/\\])/g, '\\$1'); } function getType(value) { // IE8 doesn't return [Object Undefined] but [Object Object] for undefined value if (value === undefined) { return 'Undefined'; } return String(Object.prototype.toString.call(value)).slice(8, -1); } function isArray(obj) { return getType(obj) === 'Array'; } function filterArrayValues(data, value) { var lookup = {}; var i, length; if (getType(value) === 'RegExp') { lookup = null; } else if (isArray(value)) { for (i = 0, length = value.length; i < length; i++) { lookup[value[i]] = true; } } else { lookup[value] = true; } for (i = 0, 
length = data.length; i < length; i++) { /*jshint laxbreak: true */ var _match = lookup && lookup[data[i]] !== undefined || !lookup && value.test(data[i]); /*jshint laxbreak: false */ if (_match) { data.splice(i, 1); length--; i--; } } return data; } function arrayContains(list, value) { var i, length; // value may be string, number, array, regexp if (isArray(value)) { // Note: this can be optimized to O(n) (instead of current O(m * n)) for (i = 0, length = value.length; i < length; i++) { if (!arrayContains(list, value[i])) { return false; } } return true; } var _type = getType(value); for (i = 0, length = list.length; i < length; i++) { if (_type === 'RegExp') { if (typeof list[i] === 'string' && list[i].match(value)) { return true; } } else if (list[i] === value) { return true; } } return false; } function arraysEqual(one, two) { if (!isArray(one) || !isArray(two)) { return false; } // arrays can't be equal if they have different amount of content if (one.length !== two.length) { return false; } one.sort(); two.sort(); for (var i = 0, l = one.length; i < l; i++) { if (one[i] !== two[i]) { return false; } } return true; } function trimSlashes(text) { var trim_expression = /^\/+|\/+$/g; return text.replace(trim_expression, ''); } URI._parts = function() { return { protocol: null, username: null, password: null, hostname: null, urn: null, port: null, path: null, query: null, fragment: null, // state duplicateQueryParameters: URI.duplicateQueryParameters, escapeQuerySpace: URI.escapeQuerySpace }; }; // state: allow duplicate query parameters (a=1&a=1) URI.duplicateQueryParameters = false; // state: replaces + with %20 (space in query strings) URI.escapeQuerySpace = true; // static properties URI.protocol_expression = /^[a-z][a-z0-9.+-]*$/i; URI.idn_expression = /[^a-z0-9\.-]/i; URI.punycode_expression = /(xn--)/i; // well, 333.444.555.666 matches, but it sure ain't no IPv4 - do we care? 
URI.ip4_expression = /^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$/; // credits to Rich Brown // source: http://forums.intermapper.com/viewtopic.php?p=1096#1096 // specification: http://www.ietf.org/rfc/rfc4291.txt URI.ip6_expression = /^\s*((([0-9A-Fa-f]{1,4}:){7}([0-9A-Fa-f]{1,4}|:))|(([0-9A-Fa-f]{1,4}:){6}(:[0-9A-Fa-f]{1,4}|((25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)(\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)){3})|:))|(([0-9A-Fa-f]{1,4}:){5}(((:[0-9A-Fa-f]{1,4}){1,2})|:((25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)(\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)){3})|:))|(([0-9A-Fa-f]{1,4}:){4}(((:[0-9A-Fa-f]{1,4}){1,3})|((:[0-9A-Fa-f]{1,4})?:((25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)(\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){3}(((:[0-9A-Fa-f]{1,4}){1,4})|((:[0-9A-Fa-f]{1,4}){0,2}:((25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)(\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){2}(((:[0-9A-Fa-f]{1,4}){1,5})|((:[0-9A-Fa-f]{1,4}){0,3}:((25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)(\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){1}(((:[0-9A-Fa-f]{1,4}){1,6})|((:[0-9A-Fa-f]{1,4}){0,4}:((25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)(\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)){3}))|:))|(:(((:[0-9A-Fa-f]{1,4}){1,7})|((:[0-9A-Fa-f]{1,4}){0,5}:((25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)(\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)){3}))|:)))(%.+)?\s*$/; // expression used is "gruber revised" (@gruber v2) determined to be the // best solution in a regex-golf we did a couple of ages ago at // * http://mathiasbynens.be/demo/url-regex // * http://rodneyrehm.de/t/url-regex.html URI.find_uri_expression = /\b((?:[a-z][\w-]+:(?:\/{1,3}|[a-z0-9%])|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}\/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:'".,<>?«»“”‘’]))/ig; URI.findUri = { // valid "scheme://" or "www." 
start: /\b(?:([a-z][a-z0-9.+-]*:\/\/)|www\.)/gi, // everything up to the next whitespace end: /[\s\r\n]|$/, // trim trailing punctuation captured by end RegExp trim: /[`!()\[\]{};:'".,<>?«»“”„‘’]+$/ }; // http://www.iana.org/assignments/uri-schemes.html // http://en.wikipedia.org/wiki/List_of_TCP_and_UDP_port_numbers#Well-known_ports URI.defaultPorts = { http: '80', https: '443', ftp: '21', gopher: '70', ws: '80', wss: '443' }; // allowed hostname characters according to RFC 3986 // ALPHA DIGIT "-" "." "_" "~" "!" "$" "&" "'" "(" ")" "*" "+" "," ";" "=" %encoded // I've never seen a (non-IDN) hostname other than: ALPHA DIGIT . - URI.invalid_hostname_characters = /[^a-zA-Z0-9\.-]/; // map DOM Elements to their URI attribute URI.domAttributes = { 'a': 'href', 'blockquote': 'cite', 'link': 'href', 'base': 'href', 'script': 'src', 'form': 'action', 'img': 'src', 'area': 'href', 'iframe': 'src', 'embed': 'src', 'source': 'src', 'track': 'src', 'input': 'src', // but only if type="image" 'audio': 'src', 'video': 'src' }; URI.getDomAttribute = function(node) { if (!node || !node.nodeName) { return undefined; } var nodeName = node.nodeName.toLowerCase(); // should only expose src for type="image" if (nodeName === 'input' && node.type !== 'image') { return undefined; } return URI.domAttributes[nodeName]; }; function escapeForDumbFirefox36(value) { // https://github.com/medialize/URI.js/issues/91 return escape(value); } // encoding / decoding according to RFC3986 function strictEncodeURIComponent(string) { // see https://developer.mozilla.org/en-US/docs/JavaScript/Reference/Global_Objects/encodeURIComponent return encodeURIComponent(string) .replace(/[!'()*]/g, escapeForDumbFirefox36) .replace(/\*/g, '%2A'); } URI.encode = strictEncodeURIComponent; URI.decode = decodeURIComponent; URI.iso8859 = function() { URI.encode = escape; URI.decode = unescape; }; URI.unicode = function() { URI.encode = strictEncodeURIComponent; URI.decode = decodeURIComponent; }; URI.characters = { 
pathname: { encode: { // RFC3986 2.1: For consistency, URI producers and normalizers should // use uppercase hexadecimal digits for all percent-encodings. expression: /%(24|26|2B|2C|3B|3D|3A|40)/ig, map: { // -._~!'()* '%24': '$', '%26': '&', '%2B': '+', '%2C': ',', '%3B': ';', '%3D': '=', '%3A': ':', '%40': '@' } }, decode: { expression: /[\/\?#]/g, map: { '/': '%2F', '?': '%3F', '#': '%23' } } }, reserved: { encode: { // RFC3986 2.1: For consistency, URI producers and normalizers should // use uppercase hexadecimal digits for all percent-encodings. expression: /%(21|23|24|26|27|28|29|2A|2B|2C|2F|3A|3B|3D|3F|40|5B|5D)/ig, map: { // gen-delims '%3A': ':', '%2F': '/', '%3F': '?', '%23': '#', '%5B': '[', '%5D': ']', '%40': '@', // sub-delims '%21': '!', '%24': '$', '%26': '&', '%27': '\'', '%28': '(', '%29': ')', '%2A': '*', '%2B': '+', '%2C': ',', '%3B': ';', '%3D': '=' } } }, urnpath: { // The characters under `encode` are the characters called out by RFC 2141 as being acceptable // for usage in a URN. RFC2141 also calls out "-", ".", and "_" as acceptable characters, but // these aren't encoded by encodeURIComponent, so we don't have to call them out here. Also // note that the colon character is not featured in the encoding map; this is because URI.js // gives the colons in URNs semantic meaning as the delimiters of path segements, and so it // should not appear unencoded in a segment itself. // See also the note above about RFC3986 and capitalalized hex digits. encode: { expression: /%(21|24|27|28|29|2A|2B|2C|3B|3D|40)/ig, map: { '%21': '!', '%24': '$', '%27': '\'', '%28': '(', '%29': ')', '%2A': '*', '%2B': '+', '%2C': ',', '%3B': ';', '%3D': '=', '%40': '@' } }, // These characters are the characters called out by RFC2141 as "reserved" characters that // should never appear in a URN, plus the colon character (see note above). 
decode: { expression: /[\/\?#:]/g, map: { '/': '%2F', '?': '%3F', '#': '%23', ':': '%3A' } } } }; URI.encodeQuery = function(string, escapeQuerySpace) { var escaped = URI.encode(string + ''); if (escapeQuerySpace === undefined) { escapeQuerySpace = URI.escapeQuerySpace; } return escapeQuerySpace ? escaped.replace(/%20/g, '+') : escaped; }; URI.decodeQuery = function(string, escapeQuerySpace) { string += ''; if (escapeQuerySpace === undefined) { escapeQuerySpace = URI.escapeQuerySpace; } try { return URI.decode(escapeQuerySpace ? string.replace(/\+/g, '%20') : string); } catch(e) { // we're not going to mess with weird encodings, // give up and return the undecoded original string // see https://github.com/medialize/URI.js/issues/87 // see https://github.com/medialize/URI.js/issues/92 return string; } }; // generate encode/decode path functions var _parts = {'encode':'encode', 'decode':'decode'}; var _part; var generateAccessor = function(_group, _part) { return function(string) { try { return URI[_part](string + '').replace(URI.characters[_group][_part].expression, function(c) { return URI.characters[_group][_part].map[c]; }); } catch (e) { // we're not going to mess with weird encodings, // give up and return the undecoded original string // see https://github.com/medialize/URI.js/issues/87 // see https://github.com/medialize/URI.js/issues/92 return string; } }; }; for (_part in _parts) { URI[_part + 'PathSegment'] = generateAccessor('pathname', _parts[_part]); URI[_part + 'UrnPathSegment'] = generateAccessor('urnpath', _parts[_part]); } var generateSegmentedPathFunction = function(_sep, _codingFuncName, _innerCodingFuncName) { return function(string) { // Why pass in names of functions, rather than the function objects themselves? The // definitions of some functions (but in particular, URI.decode) will occasionally change due // to URI.js having ISO8859 and Unicode modes. 
Passing in the name and getting it will ensure // that the functions we use here are "fresh". var actualCodingFunc; if (!_innerCodingFuncName) { actualCodingFunc = URI[_codingFuncName]; } else { actualCodingFunc = function(string) { return URI[_codingFuncName](URI[_innerCodingFuncName](string)); }; } var segments = (string + '').split(_sep); for (var i = 0, length = segments.length; i < length; i++) { segments[i] = actualCodingFunc(segments[i]); } return segments.join(_sep); }; }; // This takes place outside the above loop because we don't want, e.g., encodeUrnPath functions. URI.decodePath = generateSegmentedPathFunction('/', 'decodePathSegment'); URI.decodeUrnPath = generateSegmentedPathFunction(':', 'decodeUrnPathSegment'); URI.recodePath = generateSegmentedPathFunction('/', 'encodePathSegment', 'decode'); URI.recodeUrnPath = generateSegmentedPathFunction(':', 'encodeUrnPathSegment', 'decode'); URI.encodeReserved = generateAccessor('reserved', 'encode'); URI.parse = function(string, parts) { var pos; if (!parts) { parts = {}; } // [protocol"://"[username[":"password]"@"]hostname[":"port]"/"?][path]["?"querystring]["#"fragment] // extract fragment pos = string.indexOf('#'); if (pos > -1) { // escaping? parts.fragment = string.substring(pos + 1) || null; string = string.substring(0, pos); } // extract query pos = string.indexOf('?'); if (pos > -1) { // escaping? 
parts.query = string.substring(pos + 1) || null; string = string.substring(0, pos); } // extract protocol if (string.substring(0, 2) === '//') { // relative-scheme parts.protocol = null; string = string.substring(2); // extract "user:pass@host:port" string = URI.parseAuthority(string, parts); } else { pos = string.indexOf(':'); if (pos > -1) { parts.protocol = string.substring(0, pos) || null; if (parts.protocol && !parts.protocol.match(URI.protocol_expression)) { // : may be within the path parts.protocol = undefined; } else if (string.substring(pos + 1, pos + 3) === '//') { string = string.substring(pos + 3); // extract "user:pass@host:port" string = URI.parseAuthority(string, parts); } else { string = string.substring(pos + 1); parts.urn = true; } } } // what's left must be the path parts.path = string; // and we're done return parts; }; URI.parseHost = function(string, parts) { // Copy chrome, IE, opera backslash-handling behavior. // Back slashes before the query string get converted to forward slashes // See: https://github.com/joyent/node/blob/386fd24f49b0e9d1a8a076592a404168faeecc34/lib/url.js#L115-L124 // See: https://code.google.com/p/chromium/issues/detail?id=25916 // https://github.com/medialize/URI.js/pull/233 string = string.replace(/\\/g, '/'); // extract host:port var pos = string.indexOf('/'); var bracketPos; var t; if (pos === -1) { pos = string.length; } if (string.charAt(0) === '[') { // IPv6 host - http://tools.ietf.org/html/draft-ietf-6man-text-addr-representation-04#section-6 // I claim most client software breaks on IPv6 anyways. 
To simplify things, URI only accepts // IPv6+port in the format [2001:db8::1]:80 (for the time being) bracketPos = string.indexOf(']'); parts.hostname = string.substring(1, bracketPos) || null; parts.port = string.substring(bracketPos + 2, pos) || null; if (parts.port === '/') { parts.port = null; } } else { var firstColon = string.indexOf(':'); var firstSlash = string.indexOf('/'); var nextColon = string.indexOf(':', firstColon + 1); if (nextColon !== -1 && (firstSlash === -1 || nextColon < firstSlash)) { // IPv6 host contains multiple colons - but no port // this notation is actually not allowed by RFC 3986, but we're a liberal parser parts.hostname = string.substring(0, pos) || null; parts.port = null; } else { t = string.substring(0, pos).split(':'); parts.hostname = t[0] || null; parts.port = t[1] || null; } } if (parts.hostname && string.substring(pos).charAt(0) !== '/') { pos++; string = '/' + string; } return string.substring(pos) || '/'; }; URI.parseAuthority = function(string, parts) { string = URI.parseUserinfo(string, parts); return URI.parseHost(string, parts); }; URI.parseUserinfo = function(string, parts) { // extract username:password var firstSlash = string.indexOf('/'); var pos = string.lastIndexOf('@', firstSlash > -1 ? firstSlash : string.length - 1); var t; // authority@ must come before /path if (pos > -1 && (firstSlash === -1 || pos < firstSlash)) { t = string.substring(0, pos).split(':'); parts.username = t[0] ? URI.decode(t[0]) : null; t.shift(); parts.password = t[0] ? 
URI.decode(t.join(':')) : null; string = string.substring(pos + 1); } else { parts.username = null; parts.password = null; } return string; }; URI.parseQuery = function(string, escapeQuerySpace) { if (!string) { return {}; } // throw out the funky business - "?"[name"="value"&"]+ string = string.replace(/&+/g, '&').replace(/^\?*&*|&+$/g, ''); if (!string) { return {}; } var items = {}; var splits = string.split('&'); var length = splits.length; var v, name, value; for (var i = 0; i < length; i++) { v = splits[i].split('='); name = URI.decodeQuery(v.shift(), escapeQuerySpace); // no "=" is null according to http://dvcs.w3.org/hg/url/raw-file/tip/Overview.html#collect-url-parameters value = v.length ? URI.decodeQuery(v.join('='), escapeQuerySpace) : null; if (hasOwn.call(items, name)) { if (typeof items[name] === 'string' || items[name] === null) { items[name] = [items[name]]; } items[name].push(value); } else { items[name] = value; } } return items; }; URI.build = function(parts) { var t = ''; if (parts.protocol) { t += parts.protocol + ':'; } if (!parts.urn && (t || parts.hostname)) { t += '//'; } t += (URI.buildAuthority(parts) || ''); if (typeof parts.path === 'string') { if (parts.path.charAt(0) !== '/' && typeof parts.hostname === 'string') { t += '/'; } t += parts.path; } if (typeof parts.query === 'string' && parts.query) { t += '?' 
+ parts.query; } if (typeof parts.fragment === 'string' && parts.fragment) { t += '#' + parts.fragment; } return t; }; URI.buildHost = function(parts) { var t = ''; if (!parts.hostname) { return ''; } else if (URI.ip6_expression.test(parts.hostname)) { t += '[' + parts.hostname + ']'; } else { t += parts.hostname; } if (parts.port) { t += ':' + parts.port; } return t; }; URI.buildAuthority = function(parts) { return URI.buildUserinfo(parts) + URI.buildHost(parts); }; URI.buildUserinfo = function(parts) { var t = ''; if (parts.username) { t += URI.encode(parts.username); } if (parts.password) { t += ':' + URI.encode(parts.password); } if (t) { t += '@'; } return t; }; URI.buildQuery = function(data, duplicateQueryParameters, escapeQuerySpace) { // according to http://tools.ietf.org/html/rfc3986 or http://labs.apache.org/webarch/uri/rfc/rfc3986.html // being »-._~!$&'()*+,;=:@/?« %HEX and alnum are allowed // the RFC explicitly states ?/foo being a valid use case, no mention of parameter syntax! 
// URI.js treats the query string as being application/x-www-form-urlencoded // see http://www.w3.org/TR/REC-html40/interact/forms.html#form-content-type var t = ''; var unique, key, i, length; for (key in data) { if (hasOwn.call(data, key) && key) { if (isArray(data[key])) { unique = {}; for (i = 0, length = data[key].length; i < length; i++) { if (data[key][i] !== undefined && unique[data[key][i] + ''] === undefined) { t += '&' + URI.buildQueryParameter(key, data[key][i], escapeQuerySpace); if (duplicateQueryParameters !== true) { unique[data[key][i] + ''] = true; } } } } else if (data[key] !== undefined) { t += '&' + URI.buildQueryParameter(key, data[key], escapeQuerySpace); } } } return t.substring(1); }; URI.buildQueryParameter = function(name, value, escapeQuerySpace) { // http://www.w3.org/TR/REC-html40/interact/forms.html#form-content-type -- application/x-www-form-urlencoded // don't append "=" for null values, according to http://dvcs.w3.org/hg/url/raw-file/tip/Overview.html#url-parameter-serialization return URI.encodeQuery(name, escapeQuerySpace) + (value !== null ? 
'=' + URI.encodeQuery(value, escapeQuerySpace) : ''); }; URI.addQuery = function(data, name, value) { if (typeof name === 'object') { for (var key in name) { if (hasOwn.call(name, key)) { URI.addQuery(data, key, name[key]); } } } else if (typeof name === 'string') { if (data[name] === undefined) { data[name] = value; return; } else if (typeof data[name] === 'string') { data[name] = [data[name]]; } if (!isArray(value)) { value = [value]; } data[name] = (data[name] || []).concat(value); } else { throw new TypeError('URI.addQuery() accepts an object, string as the name parameter'); } }; URI.removeQuery = function(data, name, value) { var i, length, key; if (isArray(name)) { for (i = 0, length = name.length; i < length; i++) { data[name[i]] = undefined; } } else if (getType(name) === 'RegExp') { for (key in data) { if (name.test(key)) { data[key] = undefined; } } } else if (typeof name === 'object') { for (key in name) { if (hasOwn.call(name, key)) { URI.removeQuery(data, key, name[key]); } } } else if (typeof name === 'string') { if (value !== undefined) { if (getType(value) === 'RegExp') { if (!isArray(data[name]) && value.test(data[name])) { data[name] = undefined; } else { data[name] = filterArrayValues(data[name], value); } } else if (data[name] === String(value) && (!isArray(value) || value.length === 1)) { data[name] = undefined; } else if (isArray(data[name])) { data[name] = filterArrayValues(data[name], value); } } else { data[name] = undefined; } } else { throw new TypeError('URI.removeQuery() accepts an object, string, RegExp as the first parameter'); } }; URI.hasQuery = function(data, name, value, withinArray) { switch (getType(name)) { case 'String': // Nothing to do here break; case 'RegExp': for (var key in data) { if (hasOwn.call(data, key)) { if (name.test(key) && (value === undefined || URI.hasQuery(data, key, value))) { return true; } } } return false; case 'Object': for (var _key in name) { if (hasOwn.call(name, _key)) { if (!URI.hasQuery(data, 
_key, name[_key])) { return false; } } } return true; default: throw new TypeError('URI.hasQuery() accepts a string, regular expression or object as the name parameter'); } switch (getType(value)) { case 'Undefined': // true if exists (but may be empty) return name in data; // data[name] !== undefined; case 'Boolean': // true if exists and non-empty var _booly = Boolean(isArray(data[name]) ? data[name].length : data[name]); return value === _booly; case 'Function': // allow complex comparison return !!value(data[name], name, data); case 'Array': if (!isArray(data[name])) { return false; } var op = withinArray ? arrayContains : arraysEqual; return op(data[name], value); case 'RegExp': if (!isArray(data[name])) { return Boolean(data[name] && data[name].match(value)); } if (!withinArray) { return false; } return arrayContains(data[name], value); case 'Number': value = String(value); /* falls through */ case 'String': if (!isArray(data[name])) { return data[name] === value; } if (!withinArray) { return false; } return arrayContains(data[name], value); default: throw new TypeError('URI.hasQuery() accepts undefined, boolean, string, number, RegExp, Function as the value parameter'); } }; URI.joinPaths = function() { var input = []; var segments = []; var nonEmptySegments = 0; for (var i = 0; i < arguments.length; i++) { var url = new URI(arguments[i]); input.push(url); var _segments = url.segment(); for (var s = 0; s < _segments.length; s++) { if (typeof _segments[s] === 'string') { segments.push(_segments[s]); } if (_segments[s]) { nonEmptySegments++; } } } if (!segments.length || !nonEmptySegments) { return new URI(''); } var uri = new URI('').segment(segments); if (input[0].path() === '' || input[0].path().slice(0, 1) === '/') { uri.path('/' + uri.path()); } return uri.normalize(); }; URI.commonPath = function(one, two) { var length = Math.min(one.length, two.length); var pos; // find first non-matching character for (pos = 0; pos < length; pos++) { if 
(one.charAt(pos) !== two.charAt(pos)) { pos--; break; } } if (pos < 1) { return one.charAt(0) === two.charAt(0) && one.charAt(0) === '/' ? '/' : ''; } // revert to last / if (one.charAt(pos) !== '/' || two.charAt(pos) !== '/') { pos = one.substring(0, pos).lastIndexOf('/'); } return one.substring(0, pos + 1); }; URI.withinString = function(string, callback, options) { options || (options = {}); var _start = options.start || URI.findUri.start; var _end = options.end || URI.findUri.end; var _trim = options.trim || URI.findUri.trim; var _attributeOpen = /[a-z0-9-]=["']?$/i; _start.lastIndex = 0; while (true) { var match = _start.exec(string); if (!match) { break; } var start = match.index; if (options.ignoreHtml) { // attribut(e=["']?$) var attributeOpen = string.slice(Math.max(start - 3, 0), start); if (attributeOpen && _attributeOpen.test(attributeOpen)) { continue; } } var end = start + string.slice(start).search(_end); var slice = string.slice(start, end).replace(_trim, ''); if (options.ignore && options.ignore.test(slice)) { continue; } end = start + slice.length; var result = callback(slice, start, end, string); if (result === undefined) { _start.lastIndex = end; continue; } result = String(result); string = string.slice(0, start) + result + string.slice(end); _start.lastIndex = start + result.length; } _start.lastIndex = 0; return string; }; URI.ensureValidHostname = function(v) { // Theoretically URIs allow percent-encoding in Hostnames (according to RFC 3986) // they are not part of DNS and therefore ignored by URI.js if (v.match(URI.invalid_hostname_characters)) { // test punycode if (!punycode) { throw new TypeError('Hostname "' + v + '" contains characters other than [A-Z0-9.-] and Punycode.js is not available'); } if (punycode.toASCII(v).match(URI.invalid_hostname_characters)) { throw new TypeError('Hostname "' + v + '" contains characters other than [A-Z0-9.-]'); } } }; // noConflict URI.noConflict = function(removeAll) { if (removeAll) { var 
unconflicted = { URI: this.noConflict() }; if (root.URITemplate && typeof root.URITemplate.noConflict === 'function') { unconflicted.URITemplate = root.URITemplate.noConflict(); } if (root.IPv6 && typeof root.IPv6.noConflict === 'function') { unconflicted.IPv6 = root.IPv6.noConflict(); } if (root.SecondLevelDomains && typeof root.SecondLevelDomains.noConflict === 'function') { unconflicted.SecondLevelDomains = root.SecondLevelDomains.noConflict(); } return unconflicted; } else if (root.URI === this) { root.URI = _URI; } return this; }; p.build = function(deferBuild) { if (deferBuild === true) { this._deferred_build = true; } else if (deferBuild === undefined || this._deferred_build) { this._string = URI.build(this._parts); this._deferred_build = false; } return this; }; p.clone = function() { return new URI(this); }; p.valueOf = p.toString = function() { return this.build(false)._string; }; function generateSimpleAccessor(_part){ return function(v, build) { if (v === undefined) { return this._parts[_part] || ''; } else { this._parts[_part] = v || null; this.build(!build); return this; } }; } function generatePrefixAccessor(_part, _key){ return function(v, build) { if (v === undefined) { return this._parts[_part] || ''; } else { if (v !== null) { v = v + ''; if (v.charAt(0) === _key) { v = v.substring(1); } } this._parts[_part] = v; this.build(!build); return this; } }; } p.protocol = generateSimpleAccessor('protocol'); p.username = generateSimpleAccessor('username'); p.password = generateSimpleAccessor('password'); p.hostname = generateSimpleAccessor('hostname'); p.port = generateSimpleAccessor('port'); p.query = generatePrefixAccessor('query', '?'); p.fragment = generatePrefixAccessor('fragment', '#'); p.search = function(v, build) { var t = this.query(v, build); return typeof t === 'string' && t.length ? ('?' + t) : t; }; p.hash = function(v, build) { var t = this.fragment(v, build); return typeof t === 'string' && t.length ? 
('#' + t) : t; }; p.pathname = function(v, build) { if (v === undefined || v === true) { var res = this._parts.path || (this._parts.hostname ? '/' : ''); return v ? (this._parts.urn ? URI.decodeUrnPath : URI.decodePath)(res) : res; } else { if (this._parts.urn) { this._parts.path = v ? URI.recodeUrnPath(v) : ''; } else { this._parts.path = v ? URI.recodePath(v) : '/'; } this.build(!build); return this; } }; p.path = p.pathname; p.href = function(href, build) { var key; if (href === undefined) { return this.toString(); } this._string = ''; this._parts = URI._parts(); var _URI = href instanceof URI; var _object = typeof href === 'object' && (href.hostname || href.path || href.pathname); if (href.nodeName) { var attribute = URI.getDomAttribute(href); href = href[attribute] || ''; _object = false; } // window.location is reported to be an object, but it's not the sort // of object we're looking for: // * location.protocol ends with a colon // * location.query != object.search // * location.hash != object.fragment // simply serializing the unknown object should do the trick // (for location, not for everything...) if (!_URI && _object && href.pathname !== undefined) { href = href.toString(); } if (typeof href === 'string' || href instanceof String) { this._parts = URI.parse(String(href), this._parts); } else if (_URI || _object) { var src = _URI ? 
href._parts : href; for (key in src) { if (hasOwn.call(this._parts, key)) { this._parts[key] = src[key]; } } } else { throw new TypeError('invalid input'); } this.build(!build); return this; }; // identification accessors p.is = function(what) { var ip = false; var ip4 = false; var ip6 = false; var name = false; var sld = false; var idn = false; var punycode = false; var relative = !this._parts.urn; if (this._parts.hostname) { relative = false; ip4 = URI.ip4_expression.test(this._parts.hostname); ip6 = URI.ip6_expression.test(this._parts.hostname); ip = ip4 || ip6; name = !ip; sld = name && SLD && SLD.has(this._parts.hostname); idn = name && URI.idn_expression.test(this._parts.hostname); punycode = name && URI.punycode_expression.test(this._parts.hostname); } switch (what.toLowerCase()) { case 'relative': return relative; case 'absolute': return !relative; // hostname identification case 'domain': case 'name': return name; case 'sld': return sld; case 'ip': return ip; case 'ip4': case 'ipv4': case 'inet4': return ip4; case 'ip6': case 'ipv6': case 'inet6': return ip6; case 'idn': return idn; case 'url': return !this._parts.urn; case 'urn': return !!this._parts.urn; case 'punycode': return punycode; } return null; }; // component specific input validation var _protocol = p.protocol; var _port = p.port; var _hostname = p.hostname; p.protocol = function(v, build) { if (v !== undefined) { if (v) { // accept trailing :// v = v.replace(/:(\/\/)?$/, ''); if (!v.match(URI.protocol_expression)) { throw new TypeError('Protocol "' + v + '" contains characters other than [A-Z0-9.+-] or doesn\'t start with [A-Z]'); } } } return _protocol.call(this, v, build); }; p.scheme = p.protocol; p.port = function(v, build) { if (this._parts.urn) { return v === undefined ? 
'' : this; } if (v !== undefined) { if (v === 0) { v = null; } if (v) { v += ''; if (v.charAt(0) === ':') { v = v.substring(1); } if (v.match(/[^0-9]/)) { throw new TypeError('Port "' + v + '" contains characters other than [0-9]'); } } } return _port.call(this, v, build); }; p.hostname = function(v, build) { if (this._parts.urn) { return v === undefined ? '' : this; } if (v !== undefined) { var x = {}; var res = URI.parseHost(v, x); if (res !== '/') { throw new TypeError('Hostname "' + v + '" contains characters other than [A-Z0-9.-]'); } v = x.hostname; } return _hostname.call(this, v, build); }; // compound accessors p.origin = function(v, build) { if (this._parts.urn) { return v === undefined ? '' : this; } if (v === undefined) { var protocol = this.protocol(); var authority = this.authority(); if (!authority) { return ''; } return (protocol ? protocol + '://' : '') + this.authority(); } else { var origin = URI(v); this .protocol(origin.protocol()) .authority(origin.authority()) .build(!build); return this; } }; p.host = function(v, build) { if (this._parts.urn) { return v === undefined ? '' : this; } if (v === undefined) { return this._parts.hostname ? URI.buildHost(this._parts) : ''; } else { var res = URI.parseHost(v, this._parts); if (res !== '/') { throw new TypeError('Hostname "' + v + '" contains characters other than [A-Z0-9.-]'); } this.build(!build); return this; } }; p.authority = function(v, build) { if (this._parts.urn) { return v === undefined ? '' : this; } if (v === undefined) { return this._parts.hostname ? URI.buildAuthority(this._parts) : ''; } else { var res = URI.parseAuthority(v, this._parts); if (res !== '/') { throw new TypeError('Hostname "' + v + '" contains characters other than [A-Z0-9.-]'); } this.build(!build); return this; } }; p.userinfo = function(v, build) { if (this._parts.urn) { return v === undefined ? '' : this; } if (v === undefined) { var t = URI.buildUserinfo(this._parts); return t ? 
t.substring(0, t.length -1) : t; } else { if (v[v.length-1] !== '@') { v += '@'; } URI.parseUserinfo(v, this._parts); this.build(!build); return this; } }; p.resource = function(v, build) { var parts; if (v === undefined) { return this.path() + this.search() + this.hash(); } parts = URI.parse(v); this._parts.path = parts.path; this._parts.query = parts.query; this._parts.fragment = parts.fragment; this.build(!build); return this; }; // fraction accessors p.subdomain = function(v, build) { if (this._parts.urn) { return v === undefined ? '' : this; } // convenience, return "www" from "www.example.org" if (v === undefined) { if (!this._parts.hostname || this.is('IP')) { return ''; } // grab domain and add another segment var end = this._parts.hostname.length - this.domain().length - 1; return this._parts.hostname.substring(0, end) || ''; } else { var e = this._parts.hostname.length - this.domain().length; var sub = this._parts.hostname.substring(0, e); var replace = new RegExp('^' + escapeRegEx(sub)); if (v && v.charAt(v.length - 1) !== '.') { v += '.'; } if (v) { URI.ensureValidHostname(v); } this._parts.hostname = this._parts.hostname.replace(replace, v); this.build(!build); return this; } }; p.domain = function(v, build) { if (this._parts.urn) { return v === undefined ? 
'' : this; } if (typeof v === 'boolean') { build = v; v = undefined; } // convenience, return "example.org" from "www.example.org" if (v === undefined) { if (!this._parts.hostname || this.is('IP')) { return ''; } // if hostname consists of 1 or 2 segments, it must be the domain var t = this._parts.hostname.match(/\./g); if (t && t.length < 2) { return this._parts.hostname; } // grab tld and add another segment var end = this._parts.hostname.length - this.tld(build).length - 1; end = this._parts.hostname.lastIndexOf('.', end -1) + 1; return this._parts.hostname.substring(end) || ''; } else { if (!v) { throw new TypeError('cannot set domain empty'); } URI.ensureValidHostname(v); if (!this._parts.hostname || this.is('IP')) { this._parts.hostname = v; } else { var replace = new RegExp(escapeRegEx(this.domain()) + '$'); this._parts.hostname = this._parts.hostname.replace(replace, v); } this.build(!build); return this; } }; p.tld = function(v, build) { if (this._parts.urn) { return v === undefined ? 
'' : this; } if (typeof v === 'boolean') { build = v; v = undefined; } // return "org" from "www.example.org" if (v === undefined) { if (!this._parts.hostname || this.is('IP')) { return ''; } var pos = this._parts.hostname.lastIndexOf('.'); var tld = this._parts.hostname.substring(pos + 1); if (build !== true && SLD && SLD.list[tld.toLowerCase()]) { return SLD.get(this._parts.hostname) || tld; } return tld; } else { var replace; if (!v) { throw new TypeError('cannot set TLD empty'); } else if (v.match(/[^a-zA-Z0-9-]/)) { if (SLD && SLD.is(v)) { replace = new RegExp(escapeRegEx(this.tld()) + '$'); this._parts.hostname = this._parts.hostname.replace(replace, v); } else { throw new TypeError('TLD "' + v + '" contains characters other than [A-Z0-9]'); } } else if (!this._parts.hostname || this.is('IP')) { throw new ReferenceError('cannot set TLD on non-domain host'); } else { replace = new RegExp(escapeRegEx(this.tld()) + '$'); this._parts.hostname = this._parts.hostname.replace(replace, v); } this.build(!build); return this; } }; p.directory = function(v, build) { if (this._parts.urn) { return v === undefined ? '' : this; } if (v === undefined || v === true) { if (!this._parts.path && !this._parts.hostname) { return ''; } if (this._parts.path === '/') { return '/'; } var end = this._parts.path.length - this.filename().length - 1; var res = this._parts.path.substring(0, end) || (this._parts.hostname ? '/' : ''); return v ? 
URI.decodePath(res) : res; } else { var e = this._parts.path.length - this.filename().length; var directory = this._parts.path.substring(0, e); var replace = new RegExp('^' + escapeRegEx(directory)); // fully qualifier directories begin with a slash if (!this.is('relative')) { if (!v) { v = '/'; } if (v.charAt(0) !== '/') { v = '/' + v; } } // directories always end with a slash if (v && v.charAt(v.length - 1) !== '/') { v += '/'; } v = URI.recodePath(v); this._parts.path = this._parts.path.replace(replace, v); this.build(!build); return this; } }; p.filename = function(v, build) { if (this._parts.urn) { return v === undefined ? '' : this; } if (v === undefined || v === true) { if (!this._parts.path || this._parts.path === '/') { return ''; } var pos = this._parts.path.lastIndexOf('/'); var res = this._parts.path.substring(pos+1); return v ? URI.decodePathSegment(res) : res; } else { var mutatedDirectory = false; if (v.charAt(0) === '/') { v = v.substring(1); } if (v.match(/\.?\//)) { mutatedDirectory = true; } var replace = new RegExp(escapeRegEx(this.filename()) + '$'); v = URI.recodePath(v); this._parts.path = this._parts.path.replace(replace, v); if (mutatedDirectory) { this.normalizePath(build); } else { this.build(!build); } return this; } }; p.suffix = function(v, build) { if (this._parts.urn) { return v === undefined ? '' : this; } if (v === undefined || v === true) { if (!this._parts.path || this._parts.path === '/') { return ''; } var filename = this.filename(); var pos = filename.lastIndexOf('.'); var s, res; if (pos === -1) { return ''; } // suffix may only contain alnum characters (yup, I made this up.) s = filename.substring(pos+1); res = (/^[a-z0-9%]+$/i).test(s) ? s : ''; return v ? URI.decodePathSegment(res) : res; } else { if (v.charAt(0) === '.') { v = v.substring(1); } var suffix = this.suffix(); var replace; if (!suffix) { if (!v) { return this; } this._parts.path += '.' + URI.recodePath(v); } else if (!v) { replace = new RegExp(escapeRegEx('.' 
+ suffix) + '$'); } else { replace = new RegExp(escapeRegEx(suffix) + '$'); } if (replace) { v = URI.recodePath(v); this._parts.path = this._parts.path.replace(replace, v); } this.build(!build); return this; } }; p.segment = function(segment, v, build) { var separator = this._parts.urn ? ':' : '/'; var path = this.path(); var absolute = path.substring(0, 1) === '/'; var segments = path.split(separator); if (segment !== undefined && typeof segment !== 'number') { build = v; v = segment; segment = undefined; } if (segment !== undefined && typeof segment !== 'number') { throw new Error('Bad segment "' + segment + '", must be 0-based integer'); } if (absolute) { segments.shift(); } if (segment < 0) { // allow negative indexes to address from the end segment = Math.max(segments.length + segment, 0); } if (v === undefined) { /*jshint laxbreak: true */ return segment === undefined ? segments : segments[segment]; /*jshint laxbreak: false */ } else if (segment === null || segments[segment] === undefined) { if (isArray(v)) { segments = []; // collapse empty elements within array for (var i=0, l=v.length; i < l; i++) { if (!v[i].length && (!segments.length || !segments[segments.length -1].length)) { continue; } if (segments.length && !segments[segments.length -1].length) { segments.pop(); } segments.push(trimSlashes(v[i])); } } else if (v || typeof v === 'string') { v = trimSlashes(v); if (segments[segments.length -1] === '') { // empty trailing elements have to be overwritten // to prevent results such as /foo//bar segments[segments.length -1] = v; } else { segments.push(v); } } } else { if (v) { segments[segment] = trimSlashes(v); } else { segments.splice(segment, 1); } } if (absolute) { segments.unshift(''); } return this.path(segments.join(separator), build); }; p.segmentCoded = function(segment, v, build) { var segments, i, l; if (typeof segment !== 'number') { build = v; v = segment; segment = undefined; } if (v === undefined) { segments = this.segment(segment, v, 
build); if (!isArray(segments)) { segments = segments !== undefined ? URI.decode(segments) : undefined; } else { for (i = 0, l = segments.length; i < l; i++) { segments[i] = URI.decode(segments[i]); } } return segments; } if (!isArray(v)) { v = (typeof v === 'string' || v instanceof String) ? URI.encode(v) : v; } else { for (i = 0, l = v.length; i < l; i++) { v[i] = URI.encode(v[i]); } } return this.segment(segment, v, build); }; // mutating query string var q = p.query; p.query = function(v, build) { if (v === true) { return URI.parseQuery(this._parts.query, this._parts.escapeQuerySpace); } else if (typeof v === 'function') { var data = URI.parseQuery(this._parts.query, this._parts.escapeQuerySpace); var result = v.call(this, data); this._parts.query = URI.buildQuery(result || data, this._parts.duplicateQueryParameters, this._parts.escapeQuerySpace); this.build(!build); return this; } else if (v !== undefined && typeof v !== 'string') { this._parts.query = URI.buildQuery(v, this._parts.duplicateQueryParameters, this._parts.escapeQuerySpace); this.build(!build); return this; } else { return q.call(this, v, build); } }; p.setQuery = function(name, value, build) { var data = URI.parseQuery(this._parts.query, this._parts.escapeQuerySpace); if (typeof name === 'string' || name instanceof String) { data[name] = value !== undefined ? value : null; } else if (typeof name === 'object') { for (var key in name) { if (hasOwn.call(name, key)) { data[key] = name[key]; } } } else { throw new TypeError('URI.addQuery() accepts an object, string as the name parameter'); } this._parts.query = URI.buildQuery(data, this._parts.duplicateQueryParameters, this._parts.escapeQuerySpace); if (typeof name !== 'string') { build = value; } this.build(!build); return this; }; p.addQuery = function(name, value, build) { var data = URI.parseQuery(this._parts.query, this._parts.escapeQuerySpace); URI.addQuery(data, name, value === undefined ? 
null : value); this._parts.query = URI.buildQuery(data, this._parts.duplicateQueryParameters, this._parts.escapeQuerySpace); if (typeof name !== 'string') { build = value; } this.build(!build); return this; }; p.removeQuery = function(name, value, build) { var data = URI.parseQuery(this._parts.query, this._parts.escapeQuerySpace); URI.removeQuery(data, name, value); this._parts.query = URI.buildQuery(data, this._parts.duplicateQueryParameters, this._parts.escapeQuerySpace); if (typeof name !== 'string') { build = value; } this.build(!build); return this; }; p.hasQuery = function(name, value, withinArray) { var data = URI.parseQuery(this._parts.query, this._parts.escapeQuerySpace); return URI.hasQuery(data, name, value, withinArray); }; p.setSearch = p.setQuery; p.addSearch = p.addQuery; p.removeSearch = p.removeQuery; p.hasSearch = p.hasQuery; // sanitizing URLs p.normalize = function() { if (this._parts.urn) { return this .normalizeProtocol(false) .normalizePath(false) .normalizeQuery(false) .normalizeFragment(false) .build(); } return this .normalizeProtocol(false) .normalizeHostname(false) .normalizePort(false) .normalizePath(false) .normalizeQuery(false) .normalizeFragment(false) .build(); }; p.normalizeProtocol = function(build) { if (typeof this._parts.protocol === 'string') { this._parts.protocol = this._parts.protocol.toLowerCase(); this.build(!build); } return this; }; p.normalizeHostname = function(build) { if (this._parts.hostname) { if (this.is('IDN') && punycode) { this._parts.hostname = punycode.toASCII(this._parts.hostname); } else if (this.is('IPv6') && IPv6) { this._parts.hostname = IPv6.best(this._parts.hostname); } this._parts.hostname = this._parts.hostname.toLowerCase(); this.build(!build); } return this; }; p.normalizePort = function(build) { // remove port of it's the protocol's default if (typeof this._parts.protocol === 'string' && this._parts.port === URI.defaultPorts[this._parts.protocol]) { this._parts.port = null; this.build(!build); } 
return this; }; p.normalizePath = function(build) { var _path = this._parts.path; if (!_path) { return this; } if (this._parts.urn) { this._parts.path = URI.recodeUrnPath(this._parts.path); this.build(!build); return this; } if (this._parts.path === '/') { return this; } _path = URI.recodePath(_path); var _was_relative; var _leadingParents = ''; var _parent, _pos; // handle relative paths if (_path.charAt(0) !== '/') { _was_relative = true; _path = '/' + _path; } // handle relative files (as opposed to directories) if (_path.slice(-3) === '/..' || _path.slice(-2) === '/.') { _path += '/'; } // resolve simples _path = _path .replace(/(\/(\.\/)+)|(\/\.$)/g, '/') .replace(/\/{2,}/g, '/'); // remember leading parents if (_was_relative) { _leadingParents = _path.substring(1).match(/^(\.\.\/)+/) || ''; if (_leadingParents) { _leadingParents = _leadingParents[0]; } } // resolve parents while (true) { _parent = _path.search(/\/\.\.(\/|$)/); if (_parent === -1) { // no more ../ to resolve break; } else if (_parent === 0) { // top level cannot be relative, skip it _path = _path.substring(3); continue; } _pos = _path.substring(0, _parent).lastIndexOf('/'); if (_pos === -1) { _pos = _parent; } _path = _path.substring(0, _pos) + _path.substring(_parent + 3); } // revert to relative if (_was_relative && this.is('relative')) { _path = _leadingParents + _path.substring(1); } this._parts.path = _path; this.build(!build); return this; }; p.normalizePathname = p.normalizePath; p.normalizeQuery = function(build) { if (typeof this._parts.query === 'string') { if (!this._parts.query.length) { this._parts.query = null; } else { this.query(URI.parseQuery(this._parts.query, this._parts.escapeQuerySpace)); } this.build(!build); } return this; }; p.normalizeFragment = function(build) { if (!this._parts.fragment) { this._parts.fragment = null; this.build(!build); } return this; }; p.normalizeSearch = p.normalizeQuery; p.normalizeHash = p.normalizeFragment; p.iso8859 = function() { // expect 
unicode input, iso8859 output var e = URI.encode; var d = URI.decode; URI.encode = escape; URI.decode = decodeURIComponent; try { this.normalize(); } finally { URI.encode = e; URI.decode = d; } return this; }; p.unicode = function() { // expect iso8859 input, unicode output var e = URI.encode; var d = URI.decode; URI.encode = strictEncodeURIComponent; URI.decode = unescape; try { this.normalize(); } finally { URI.encode = e; URI.decode = d; } return this; }; p.readable = function() { var uri = this.clone(); // removing username, password, because they shouldn't be displayed according to RFC 3986 uri.username('').password('').normalize(); var t = ''; if (uri._parts.protocol) { t += uri._parts.protocol + '://'; } if (uri._parts.hostname) { if (uri.is('punycode') && punycode) { t += punycode.toUnicode(uri._parts.hostname); if (uri._parts.port) { t += ':' + uri._parts.port; } } else { t += uri.host(); } } if (uri._parts.hostname && uri._parts.path && uri._parts.path.charAt(0) !== '/') { t += '/'; } t += uri.path(true); if (uri._parts.query) { var q = ''; for (var i = 0, qp = uri._parts.query.split('&'), l = qp.length; i < l; i++) { var kv = (qp[i] || '').split('='); q += '&' + URI.decodeQuery(kv[0], this._parts.escapeQuerySpace) .replace(/&/g, '%26'); if (kv[1] !== undefined) { q += '=' + URI.decodeQuery(kv[1], this._parts.escapeQuerySpace) .replace(/&/g, '%26'); } } t += '?' 
+ q.substring(1); } t += URI.decodeQuery(uri.hash(), true); return t; }; // resolving relative and absolute URLs p.absoluteTo = function(base) { var resolved = this.clone(); var properties = ['protocol', 'username', 'password', 'hostname', 'port']; var basedir, i, p; if (this._parts.urn) { throw new Error('URNs do not have any generally defined hierarchical components'); } if (!(base instanceof URI)) { base = new URI(base); } if (!resolved._parts.protocol) { resolved._parts.protocol = base._parts.protocol; } if (this._parts.hostname) { return resolved; } for (i = 0; (p = properties[i]); i++) { resolved._parts[p] = base._parts[p]; } if (!resolved._parts.path) { resolved._parts.path = base._parts.path; if (!resolved._parts.query) { resolved._parts.query = base._parts.query; } } else { if (resolved._parts.path.substring(-2) === '..') { resolved._parts.path += '/'; } if (resolved.path().charAt(0) !== '/') { basedir = base.directory(); basedir = basedir ? basedir : base.path().indexOf('/') === 0 ? '/' : ''; resolved._parts.path = (basedir ? 
(basedir + '/') : '') + resolved._parts.path; resolved.normalizePath(); } } resolved.build(); return resolved; }; p.relativeTo = function(base) { var relative = this.clone().normalize(); var relativeParts, baseParts, common, relativePath, basePath; if (relative._parts.urn) { throw new Error('URNs do not have any generally defined hierarchical components'); } base = new URI(base).normalize(); relativeParts = relative._parts; baseParts = base._parts; relativePath = relative.path(); basePath = base.path(); if (relativePath.charAt(0) !== '/') { throw new Error('URI is already relative'); } if (basePath.charAt(0) !== '/') { throw new Error('Cannot calculate a URI relative to another relative URI'); } if (relativeParts.protocol === baseParts.protocol) { relativeParts.protocol = null; } if (relativeParts.username !== baseParts.username || relativeParts.password !== baseParts.password) { return relative.build(); } if (relativeParts.protocol !== null || relativeParts.username !== null || relativeParts.password !== null) { return relative.build(); } if (relativeParts.hostname === baseParts.hostname && relativeParts.port === baseParts.port) { relativeParts.hostname = null; relativeParts.port = null; } else { return relative.build(); } if (relativePath === basePath) { relativeParts.path = ''; return relative.build(); } // determine common sub path common = URI.commonPath(relativePath, basePath); // If the paths have nothing in common, return a relative URL with the absolute path. 
if (!common) { return relative.build(); } var parents = baseParts.path .substring(common.length) .replace(/[^\/]*$/, '') .replace(/.*?\//g, '../'); relativeParts.path = (parents + relativeParts.path.substring(common.length)) || './'; return relative.build(); }; // comparing URIs p.equals = function(uri) { var one = this.clone(); var two = new URI(uri); var one_map = {}; var two_map = {}; var checked = {}; var one_query, two_query, key; one.normalize(); two.normalize(); // exact match if (one.toString() === two.toString()) { return true; } // extract query string one_query = one.query(); two_query = two.query(); one.query(''); two.query(''); // definitely not equal if not even non-query parts match if (one.toString() !== two.toString()) { return false; } // query parameters have the same length, even if they're permuted if (one_query.length !== two_query.length) { return false; } one_map = URI.parseQuery(one_query, this._parts.escapeQuerySpace); two_map = URI.parseQuery(two_query, this._parts.escapeQuerySpace); for (key in one_map) { if (hasOwn.call(one_map, key)) { if (!isArray(one_map[key])) { if (one_map[key] !== two_map[key]) { return false; } } else if (!arraysEqual(one_map[key], two_map[key])) { return false; } checked[key] = true; } } for (key in two_map) { if (hasOwn.call(two_map, key)) { if (!checked[key]) { // two contains a parameter not present in one return false; } } } return true; }; // state p.duplicateQueryParameters = function(v) { this._parts.duplicateQueryParameters = !!v; return this; }; p.escapeQuerySpace = function(v) { this._parts.escapeQuerySpace = !!v; return this; }; return URI; })); ================================================ FILE: luigi/static/visualiser/lib/mustache.js ================================================ /*! 
* mustache.js - Logic-less {{mustache}} templates with JavaScript * http://github.com/janl/mustache.js */ /*global define: false*/ (function (root, factory) { if (typeof exports === "object" && exports) { factory(exports); // CommonJS } else { var mustache = {}; factory(mustache); if (typeof define === "function" && define.amd) { define(mustache); // AMD } else { root.Mustache = mustache; // ================================================ FILE: luigi/target.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ The abstract :py:class:`Target` class. It is a central concept of Luigi and represents the state of the workflow. """ import abc import io import logging import os import random import tempfile import warnings from contextlib import contextmanager logger = logging.getLogger("luigi-interface") class Target(metaclass=abc.ABCMeta): """ A Target is a resource generated by a :py:class:`~luigi.task.Task`. For example, a Target might correspond to a file in HDFS or data in a database. The Target interface defines one method that must be overridden: :py:meth:`exists`, which signifies if the Target has been created or not. Typically, a :py:class:`~luigi.task.Task` will define one or more Targets as output, and the Task is considered complete if and only if each of its output Targets exist. 
""" @abc.abstractmethod def exists(self): """ Returns ``True`` if the :py:class:`Target` exists and ``False`` otherwise. """ pass class FileSystemException(Exception): """ Base class for generic file system exceptions. """ pass class FileAlreadyExists(FileSystemException): """ Raised when a file system operation can't be performed because a directory exists but is required to not exist. """ pass class MissingParentDirectory(FileSystemException): """ Raised when a parent directory doesn't exist. (Imagine mkdir without -p) """ pass class NotADirectory(FileSystemException): """ Raised when a file system operation can't be performed because an expected directory is actually a file. """ pass class FileSystem(metaclass=abc.ABCMeta): """ FileSystem abstraction used in conjunction with :py:class:`FileSystemTarget`. Typically, a FileSystem is associated with instances of a :py:class:`FileSystemTarget`. The instances of the :py:class:`FileSystemTarget` will delegate methods such as :py:meth:`FileSystemTarget.exists` and :py:meth:`FileSystemTarget.remove` to the FileSystem. Methods of FileSystem raise :py:class:`FileSystemException` if there is a problem completing the operation. """ @abc.abstractmethod def exists(self, path): """ Return ``True`` if file or directory at ``path`` exist, ``False`` otherwise :param str path: a path within the FileSystem to check for existence. """ pass @abc.abstractmethod def remove(self, path, recursive=True, skip_trash=True): """Remove file or directory at location ``path`` :param str path: a path within the FileSystem to remove. :param bool recursive: if the path is a directory, recursively remove the directory and all of its descendants. Defaults to ``True``. """ pass def mkdir(self, path, parents=True, raise_if_exists=False): """ Create directory at location ``path`` Creates the directory at ``path`` and implicitly create parent directories if they do not already exist. :param str path: a path within the FileSystem to create as a directory. 
:param bool parents: Create parent directories when necessary. When parents=False and the parent directory doesn't exist, raise luigi.target.MissingParentDirectory :param bool raise_if_exists: raise luigi.target.FileAlreadyExists if the folder already exists. """ raise NotImplementedError("mkdir() not implemented on {0}".format(self.__class__.__name__)) def isdir(self, path): """ Return ``True`` if the location at ``path`` is a directory. If not, return ``False``. :param str path: a path within the FileSystem to check as a directory. *Note*: This method is optional, not all FileSystem subclasses implements it. """ raise NotImplementedError("isdir() not implemented on {0}".format(self.__class__.__name__)) def listdir(self, path): """Return a list of files rooted in path. This returns an iterable of the files rooted at ``path``. This is intended to be a recursive listing. :param str path: a path within the FileSystem to list. *Note*: This method is optional, not all FileSystem subclasses implements it. """ raise NotImplementedError("listdir() not implemented on {0}".format(self.__class__.__name__)) def move(self, path, dest): """ Move a file, as one would expect. """ raise NotImplementedError("move() not implemented on {0}".format(self.__class__.__name__)) def rename_dont_move(self, path, dest): """ Potentially rename ``path`` to ``dest``, but don't move it into the ``dest`` folder (if it is a folder). This relates to :ref:`AtomicWrites`. This method has a reasonable but not bullet proof default implementation. It will just do ``move()`` if the file doesn't ``exists()`` already. """ warnings.warn("File system {} client doesn't support atomic mv.".format(self.__class__.__name__)) if self.exists(dest): raise FileAlreadyExists() self.move(path, dest) def rename(self, *args, **kwargs): """ Alias for ``move()`` """ self.move(*args, **kwargs) def copy(self, path, dest): """ Copy a file or a directory with contents. 
Currently, LocalFileSystem and MockFileSystem support only single file copying but S3Client copies either a file or a directory as required. """ raise NotImplementedError("copy() not implemented on {0}".format(self.__class__.__name__)) class FileSystemTarget(Target): """ Base class for FileSystem Targets like :class:`~luigi.local_target.LocalTarget` and :class:`~luigi.contrib.hdfs.HdfsTarget`. A FileSystemTarget has an associated :py:class:`FileSystem` to which certain operations can be delegated. By default, :py:meth:`exists` and :py:meth:`remove` are delegated to the :py:class:`FileSystem`, which is determined by the :py:attr:`fs` property. Methods of FileSystemTarget raise :py:class:`FileSystemException` if there is a problem completing the operation. Usage: .. code-block:: python target = FileSystemTarget('~/some_file.txt') target = FileSystemTarget(pathlib.Path('~') / 'some_file.txt') target.exists() # False """ def __init__(self, path): """ Initializes a FileSystemTarget instance. :param path: the path associated with this FileSystemTarget. """ # cast to str to allow path to be objects like pathlib.PosixPath and py._path.local.LocalPath self.path = str(path) def __str__(self): return self.path @property @abc.abstractmethod def fs(self): """ The :py:class:`FileSystem` associated with this FileSystemTarget. """ raise NotImplementedError() @abc.abstractmethod def open(self, mode): """ Open the FileSystem target. This method returns a file-like object which can either be read from or written to depending on the specified mode. :param str mode: the mode `r` opens the FileSystemTarget in read-only mode, whereas `w` will open the FileSystemTarget in write mode. Subclasses can implement additional options. Using `b` is not supported; initialize with `format=Nop` instead. """ pass def exists(self): """ Returns ``True`` if the path for this FileSystemTarget exists; ``False`` otherwise. This method is implemented by using :py:attr:`fs`. 
""" path = self.path if "*" in path or "?" in path or "[" in path or "{" in path: logger.warning("Using wildcards in path %s might lead to processing of an incomplete dataset; override exists() to suppress the warning.", path) return self.fs.exists(path) def remove(self): """ Remove the resource at the path specified by this FileSystemTarget. This method is implemented by using :py:attr:`fs`. """ self.fs.remove(self.path) @contextmanager def temporary_path(self): """ A context manager that enables a reasonably short, general and magic-less way to solve the :ref:`AtomicWrites`. * On *entering*, it will create the parent directories so the temporary_path is writeable right away. This step uses :py:meth:`FileSystem.mkdir`. * On *exiting*, it will move the temporary file if there was no exception thrown. This step uses :py:meth:`FileSystem.rename_dont_move` The file system operations will be carried out by calling them on :py:attr:`fs`. The typical use case looks like this: .. code:: python class MyTask(luigi.Task): def output(self): return MyFileSystemTarget(...) def run(self): with self.output().temporary_path() as self.temp_output_path: run_some_external_command(output_path=self.temp_output_path) """ num = random.randrange(0, 10_000_000_000) slashless_path = self.path.rstrip("/").rstrip("\\") _temp_path = "{}-luigi-tmp-{:010}{}".format(slashless_path, num, self._trailing_slash()) # TODO: os.path doesn't make sense here as it's os-dependent tmp_dir = os.path.dirname(slashless_path) if tmp_dir: self.fs.mkdir(tmp_dir, parents=True, raise_if_exists=False) yield _temp_path # We won't reach here if there was an user exception. 
self.fs.rename_dont_move(_temp_path, self.path) def _touchz(self): with self.open("w"): pass def _trailing_slash(self): # I suppose one day schema-like paths, like # file:///path/blah.txt?params=etc can be parsed too return self.path[-1] if self.path[-1] in r"\/" else "" class AtomicLocalFile(io.BufferedWriter): """Abstract class to create a Target that creates a temporary file in the local filesystem before moving it to its final destination. This class is just for the writing part of the Target. See :class:`luigi.local_target.LocalTarget` for example """ def __init__(self, path): self.__tmp_path = self.generate_tmp_path(path) self.path = path super(AtomicLocalFile, self).__init__(io.FileIO(self.__tmp_path, "w")) def close(self): super(AtomicLocalFile, self).close() self.move_to_final_destination() def generate_tmp_path(self, path): return os.path.join(tempfile.gettempdir(), "luigi-s3-tmp-%09d" % random.randrange(0, 10_000_000_000)) def move_to_final_destination(self): raise NotImplementedError() def __del__(self): if os.path.exists(self.tmp_path): os.remove(self.tmp_path) @property def tmp_path(self): return self.__tmp_path def __exit__(self, exc_type, exc, traceback): "Close/commit the file if there are no exception" if exc_type: return return super(AtomicLocalFile, self).__exit__(exc_type, exc, traceback) ================================================ FILE: luigi/task.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # """ The abstract :py:class:`Task` class. It is a central concept of Luigi and represents the state of the workflow. See :doc:`/tasks` for an overview. """ import copy import functools import hashlib import json import logging import re import traceback import warnings from collections import OrderedDict, deque from contextlib import contextmanager from typing import Any, Dict, Optional from typing_extensions import dataclass_transform import luigi from luigi import configuration, parameter from luigi.parameter import ParameterVisibility, UnconsumedParameterWarning from luigi.task_register import Register Parameter = parameter.Parameter logger = logging.getLogger("luigi-interface") TASK_ID_INCLUDE_PARAMS = 3 TASK_ID_TRUNCATE_PARAMS = 16 TASK_ID_TRUNCATE_HASH = 10 TASK_ID_INVALID_CHAR_REGEX = re.compile(r"[^A-Za-z0-9_]") _SAME_AS_PYTHON_MODULE = "_same_as_python_module" def namespace(namespace=None, scope=""): """ Call to set namespace of tasks declared after the call. It is often desired to call this function with the keyword argument ``scope=__name__``. The ``scope`` keyword makes it so that this call is only effective for task classes with a matching [*]_ ``__module__``. The default value for ``scope`` is the empty string, which means all classes. Multiple calls with the same scope simply replace each other. The namespace of a :py:class:`Task` can also be changed by specifying the property ``task_namespace``. .. code-block:: python class Task2(luigi.Task): task_namespace = 'namespace2' This explicit setting takes priority over whatever is set in the ``namespace()`` method, and it's also inherited through normal python inheritence. There's no equivalent way to set the ``task_family``. *New since Luigi 2.6.0:* ``scope`` keyword argument. .. 
[*] When there are multiple levels of matching module scopes like ``a.b`` vs ``a.b.c``, the more specific one (``a.b.c``) wins. .. seealso:: The new and better scaling :py:func:`auto_namespace` """ Register._default_namespace_dict[scope] = namespace or "" def auto_namespace(scope=""): """ Same as :py:func:`namespace`, but instead of a constant namespace, it will be set to the ``__module__`` of the task class. This is desirable for these reasons: * Two tasks with the same name will not have conflicting task families * It's more pythonic, as modules are Python's recommended way to do namespacing. * It's traceable. When you see the full name of a task, you can immediately identify where it is defined. We recommend calling this function from your package's outermost ``__init__.py`` file. The file contents could look like this: .. code-block:: python import luigi luigi.auto_namespace(scope=__name__) To reset an ``auto_namespace()`` call, you can use ``namespace(scope='my_scope')``. But this will not be needed (and is also discouraged) if you use the ``scope`` kwarg. *New since Luigi 2.6.0.* """ namespace(namespace=_SAME_AS_PYTHON_MODULE, scope=scope) def task_id_str(task_family, params): """ Returns a canonical string used to identify a particular task :param task_family: The task family (class name) of the task :param params: a dict mapping parameter names to their serialized values :return: A unique, shortened identifier corresponding to the family and params """ # task_id is a concatenation of task family, the first values of the first 3 parameters # sorted by parameter name and a md5hash of the family/parameters as a cananocalised json. 
param_str = json.dumps(params, separators=(",", ":"), sort_keys=True) param_hash = hashlib.new("md5", param_str.encode("utf-8"), usedforsecurity=False).hexdigest() param_summary = "_".join(p[:TASK_ID_TRUNCATE_PARAMS] for p in (params[p] for p in sorted(params)[:TASK_ID_INCLUDE_PARAMS])) param_summary = TASK_ID_INVALID_CHAR_REGEX.sub("_", param_summary) return "{}_{}_{}".format(task_family, param_summary, param_hash[:TASK_ID_TRUNCATE_HASH]) class BulkCompleteNotImplementedError(NotImplementedError): """This is here to trick pylint. pylint thinks anything raising NotImplementedError needs to be implemented in any subclass. bulk_complete isn't like that. This tricks pylint into thinking that the default implementation is a valid implementation and not an abstract method.""" pass @dataclass_transform(eq_default=False, order_default=False, kw_only_default=True, field_specifiers=(Parameter,)) class Task(metaclass=Register): """ This is the base class of all Luigi Tasks, the base unit of work in Luigi. A Luigi Task describes a unit or work. The key methods of a Task, which must be implemented in a subclass are: * :py:meth:`run` - the computation done by this task. * :py:meth:`requires` - the list of Tasks that this Task depends on. * :py:meth:`output` - the output :py:class:`Target` that this Task creates. Each :py:class:`~luigi.Parameter` of the Task should be declared as members: .. code:: python class MyTask(luigi.Task): count = luigi.IntParameter() second_param = luigi.Parameter() In addition to any declared properties and methods, there are a few non-declared properties, which are created by the :py:class:`Register` metaclass: """ _event_callbacks: Dict[Any, Any] = {} #: Priority of the task: the scheduler should favor available #: tasks with higher priority values first. #: See :ref:`Task.priority` priority = 0 disabled = False #: Resources used by the task. Should be formatted like {"scp": 1} to indicate that the #: task requires 1 unit of the scp resource. 
resources: Dict[str, Any] = {} #: Number of seconds after which to time out the run function. #: No timeout if set to 0. #: Defaults to 0 or worker-timeout value in config worker_timeout: Optional[int] = None #: Maximum number of tasks to run together as a batch. Infinite by default max_batch_size = float("inf") @property def batchable(self): """ True if this instance can be run as part of a batch. By default, True if it has any batched parameters """ return bool(self.batch_param_names()) @property def retry_count(self): """ Override this positive integer to have different ``retry_count`` at task level Check :ref:`scheduler-config` """ return None @property def disable_hard_timeout(self): """ Override this positive integer to have different ``disable_hard_timeout`` at task level. Check :ref:`scheduler-config` """ return None @property def disable_window(self): """ Override this positive integer to have different ``disable_window`` at task level. Check :ref:`scheduler-config` """ return None @property def disable_window_seconds(self): warnings.warn("Use of `disable_window_seconds` has been deprecated, use `disable_window` instead", DeprecationWarning) return self.disable_window @property def owner_email(self): """ Override this to send out additional error emails to task owner, in addition to the one defined in the global configuration. This should return a string or a list of strings. e.g. 'test@exmaple.com' or ['test1@example.com', 'test2@example.com'] """ return None def _owner_list(self): """ Turns the owner_email property into a list. This should not be overridden. """ owner_email = self.owner_email if owner_email is None: return [] elif isinstance(owner_email, str): return owner_email.split(",") else: return owner_email @property def use_cmdline_section(self): """Property used by core config such as `--workers` etc. 
These will be exposed without the class as prefix.""" return True @classmethod def event_handler(cls, event): """ Decorator for adding event handlers. """ def wrapped(callback): cls._event_callbacks.setdefault(cls, {}).setdefault(event, set()).add(callback) return callback return wrapped @classmethod def remove_event_handler(cls, event, callback): """ Function to remove the event handler registered previously by the cls.event_handler decorator. """ cls._event_callbacks[cls][event].remove(callback) def trigger_event(self, event, *args, **kwargs): """ Trigger that calls all of the specified events associated with this class. """ for event_class, event_callbacks in self._event_callbacks.items(): if not isinstance(self, event_class): continue for callback in event_callbacks.get(event, []): try: # callbacks are protected callback(*args, **kwargs) except KeyboardInterrupt: return except BaseException: logger.exception("Error in event callback for %r", event) @property def accepts_messages(self): """ For configuring which scheduler messages can be received. When falsy, this tasks does not accept any message. When True, all messages are accepted. """ return False @property def task_module(self): """Returns what Python module to import to get access to this class.""" # TODO(erikbern): we should think about a language-agnostic mechanism return self.__class__.__module__ _visible_in_registry = True # TODO: Consider using in luigi.util as well __not_user_specified = "__not_user_specified" # This is here just to help pylint, the Register metaclass will always set # this value anyway. _namespace_at_class_time = None task_namespace = __not_user_specified """ This value can be overridden to set the namespace that will be used. (See :ref:`Task.namespaces_famlies_and_ids`) If it's not specified and you try to read this value anyway, it will return garbage. Please use :py:meth:`get_task_namespace` to read the namespace. 
Note that setting this value with ``@property`` will not work, because this is a class level value. """ @classmethod def get_task_namespace(cls): """ The task family for the given class. Note: You normally don't want to override this. """ if cls.task_namespace != cls.__not_user_specified: return cls.task_namespace elif cls._namespace_at_class_time == _SAME_AS_PYTHON_MODULE: return cls.__module__ return cls._namespace_at_class_time @property def task_family(self): """ DEPRECATED since after 2.4.0. See :py:meth:`get_task_family` instead. Hopefully there will be less meta magic in Luigi. Convenience method since a property on the metaclass isn't directly accessible through the class instances. """ return self.__class__.task_family @classmethod def get_task_family(cls): """ The task family for the given class. If ``task_namespace`` is not set, then it's simply the name of the class. Otherwise, ``.`` is prefixed to the class name. Note: You normally don't want to override this. """ if not cls.get_task_namespace(): return cls.__name__ else: return "{}.{}".format(cls.get_task_namespace(), cls.__name__) @classmethod def get_params(cls): """ Returns all of the Parameters for this Task. """ # We want to do this here and not at class instantiation, or else there is no room to extend classes dynamically params = [] for param_name in dir(cls): param_obj = getattr(cls, param_name) if not isinstance(param_obj, Parameter): continue params.append((param_name, param_obj)) # The order the parameters are created matters. See Parameter class params.sort(key=lambda t: t[1]._counter) return params @classmethod def batch_param_names(cls): return [name for name, p in cls.get_params() if p._is_batchable()] @classmethod def get_param_names(cls, include_significant=False): return [name for name, p in cls.get_params() if include_significant or p.significant] @classmethod def get_param_values(cls, params, args, kwargs): """ Get the values of the parameters from the args and kwargs. 
:param params: list of (param_name, Parameter). :param args: positional arguments :param kwargs: keyword arguments. :returns: list of `(name, value)` tuples, one for each parameter. """ result = {} params_dict = dict(params) task_family = cls.get_task_family() # In case any exceptions are thrown, create a helpful description of how the Task was invoked # TODO: should we detect non-reprable arguments? These will lead to mysterious errors exc_desc = "%s[args=%s, kwargs=%s]" % (task_family, args, kwargs) # Fill in the positional arguments positional_params = [(n, p) for n, p in params if p.positional] for i, arg in enumerate(args): if i >= len(positional_params): raise parameter.UnknownParameterException("%s: takes at most %d parameters (%d given)" % (exc_desc, len(positional_params), len(args))) param_name, param_obj = positional_params[i] result[param_name] = param_obj.normalize(arg) # Then the keyword arguments for param_name, arg in kwargs.items(): if param_name in result: raise parameter.DuplicateParameterException("%s: parameter %s was already set as a positional parameter" % (exc_desc, param_name)) if param_name not in params_dict: raise parameter.UnknownParameterException("%s: unknown parameter %s" % (exc_desc, param_name)) result[param_name] = params_dict[param_name].normalize(arg) # Then use the defaults for anything not filled in for param_name, param_obj in params: if param_name not in result: try: has_task_value = param_obj.has_task_value(task_family, param_name) except Exception as exc: raise ValueError("%s: Error when parsing the default value of '%s'" % (exc_desc, param_name)) from exc if not has_task_value: raise parameter.MissingParameterException("%s: requires the '%s' parameter to be set" % (exc_desc, param_name)) result[param_name] = param_obj.task_value(task_family, param_name) def list_to_tuple(x): """Make tuples out of lists and sets to allow hashing""" if isinstance(x, list) or isinstance(x, set): return tuple(x) else: return x # Check for 
unconsumed parameters conf = configuration.get_config() if not hasattr(cls, "_unconsumed_params"): cls._unconsumed_params = set() if task_family in conf.sections(): ignore_unconsumed = getattr(cls, "ignore_unconsumed", set()) for key, value in conf[task_family].items(): key = key.replace("-", "_") composite_key = f"{task_family}_{key}" if key not in result and key not in ignore_unconsumed and composite_key not in cls._unconsumed_params: warnings.warn( f"The configuration contains the parameter '{key}' with value '{value}' that is not consumed by the task '{task_family}'.", UnconsumedParameterWarning, ) cls._unconsumed_params.add(composite_key) # Sort it by the correct order and make a list return [(param_name, list_to_tuple(result[param_name])) for param_name, param_obj in params] def __init__(self, *args, **kwargs): params = self.get_params() param_values = self.get_param_values(params, args, kwargs) # Set all values on class instance for key, value in param_values: setattr(self, key, value) # Register kwargs as an attribute on the class. Might be useful self.param_kwargs = dict(param_values) self._warn_on_wrong_param_types() self.task_id = task_id_str(self.get_task_family(), self.to_str_params(only_significant=True, only_public=True)) self.__hash = hash(self.task_id) self.set_tracking_url = None self.set_status_message = None self.set_progress_percentage = None @property def param_args(self): warnings.warn("Use of param_args has been deprecated.", DeprecationWarning) return tuple(self.param_kwargs[k] for k, v in self.get_params()) def initialized(self): """ Returns ``True`` if the Task is initialized and ``False`` otherwise. """ return hasattr(self, "task_id") def _warn_on_wrong_param_types(self): params = dict(self.get_params()) for param_name, param_value in self.param_kwargs.items(): params[param_name]._warn_on_wrong_param_type(param_name, param_value) @classmethod def from_str_params(cls, params_str): """ Creates an instance from a str->str hash. 
:param params_str: dict of param name -> value as string. """ kwargs = {} for param_name, param in cls.get_params(): if param_name in params_str: param_str = params_str[param_name] if isinstance(param_str, list): kwargs[param_name] = param._parse_list(param_str) else: kwargs[param_name] = param.parse(param_str) return cls(**kwargs) def to_str_params(self, only_significant=False, only_public=False): """ Convert all parameters to a str->str hash. """ params_str = {} params = dict(self.get_params()) for param_name, param_value in self.param_kwargs.items(): if ( ((not only_significant) or params[param_name].significant) and ((not only_public) or params[param_name].visibility == ParameterVisibility.PUBLIC) and params[param_name].visibility != ParameterVisibility.PRIVATE ): params_str[param_name] = params[param_name].serialize(param_value) return params_str def _get_param_visibilities(self): param_visibilities = {} params = dict(self.get_params()) for param_name, param_value in self.param_kwargs.items(): if params[param_name].visibility != ParameterVisibility.PRIVATE: param_visibilities[param_name] = params[param_name].visibility.serialize() return param_visibilities def clone(self, cls=None, **kwargs): """ Creates a new instance from an existing instance where some of the args have changed. 
There's at least two scenarios where this is useful (see test/clone_test.py): * remove a lot of boiler plate when you have recursive dependencies and lots of args * there's task inheritance and some logic is on the base class :param cls: :param kwargs: :return: """ if cls is None: cls = self.__class__ new_k = {} for param_name, param_class in cls.get_params(): if param_name in kwargs: new_k[param_name] = kwargs[param_name] elif hasattr(self, param_name): new_k[param_name] = getattr(self, param_name) return cls(**new_k) def __hash__(self): return self.__hash def __repr__(self): """ Build a task representation like `MyTask(param1=1.5, param2='5')` """ params = self.get_params() param_values = self.get_param_values(params, [], self.param_kwargs) # Build up task id repr_parts = [] param_objs = dict(params) for param_name, param_value in param_values: if param_objs[param_name].significant: repr_parts.append("%s=%s" % (param_name, param_objs[param_name].serialize(param_value))) task_str = "{}({})".format(self.get_task_family(), ", ".join(repr_parts)) return task_str def __eq__(self, other): return self.__class__ == other.__class__ and self.task_id == other.task_id def complete(self): """ If the task has any outputs, return ``True`` if all outputs exist. Otherwise, return ``False``. However, you may freely override this method with custom logic. """ outputs = flatten(self.output()) if len(outputs) == 0: warnings.warn("Task %r without outputs has no custom complete() method" % self, stacklevel=2) return False return all(map(lambda output: output.exists(), outputs)) @classmethod def bulk_complete(cls, parameter_tuples): """ Returns those of parameter_tuples for which this Task is complete. Override (with an efficient implementation) for efficient scheduling with range tools. Keep the logic consistent with that of complete(). """ raise BulkCompleteNotImplementedError() def output(self): """ The output that this Task produces. 
The output of the Task determines if the Task needs to be run--the task is considered finished iff the outputs all exist. Subclasses should override this method to return a single :py:class:`Target` or a list of :py:class:`Target` instances. Implementation note If running multiple workers, the output must be a resource that is accessible by all workers, such as a DFS or database. Otherwise, workers might compute the same output since they don't see the work done by other workers. See :ref:`Task.output` """ return [] # default impl def requires(self): """ The Tasks that this Task depends on. A Task will only run if all of the Tasks that it requires are completed. If your Task does not require any other Tasks, then you don't need to override this method. Otherwise, a subclass can override this method to return a single Task, a list of Task instances, or a dict whose values are Task instances. See :ref:`Task.requires` """ return [] # default impl def _requires(self): """ Override in "template" tasks which themselves are supposed to be subclassed and thus have their requires() overridden (name preserved to provide consistent end-user experience), yet need to introduce (non-input) dependencies. Must return an iterable which among others contains the _requires() of the superclass. """ return flatten(self.requires()) # base impl def process_resources(self): """ Override in "template" tasks which provide common resource functionality but allow subclasses to specify additional resources while preserving the name for consistent end-user experience. """ return self.resources # default impl def input(self): """ Returns the outputs of the Tasks returned by :py:meth:`requires` See :ref:`Task.input` :return: a list of :py:class:`Target` objects which are specified as outputs of all required Tasks. """ return getpaths(self.requires()) def deps(self): """ Internal method used by the scheduler. Returns the flattened list of requires. 
""" # used by scheduler return flatten(self._requires()) def run(self): """ The task run method, to be overridden in a subclass. See :ref:`Task.run` """ pass # default impl def on_failure(self, exception): """ Override for custom error handling. This method gets called if an exception is raised in :py:meth:`run`. The returned value of this method is json encoded and sent to the scheduler as the `expl` argument. Its string representation will be used as the body of the error email sent out if any. Default behavior is to return a string representation of the stack trace. """ traceback_string = traceback.format_exc() return "Runtime error:\n%s" % traceback_string def on_success(self): """ Override for doing custom completion handling for a larger class of tasks This method gets called when :py:meth:`run` completes without raising any exceptions. The returned value is json encoded and sent to the scheduler as the `expl` argument. Default behavior is to send an None value""" pass @contextmanager def no_unpicklable_properties(self): """ Remove unpicklable properties before dump task and resume them after. This method could be called in subtask's dump method, to ensure unpicklable properties won't break dump. This method is a context-manager which can be called as below: .. code-block: python class DummyTask(luigi): def _dump(self): with self.no_unpicklable_properties(): pickle.dumps(self) """ unpicklable_properties = tuple(luigi.worker.TaskProcess.forward_reporter_attributes.values()) reserved_properties = {} for property_name in unpicklable_properties: if hasattr(self, property_name): reserved_properties[property_name] = getattr(self, property_name) setattr(self, property_name, "placeholder_during_pickling") yield for property_name, value in reserved_properties.items(): setattr(self, property_name, value) class MixinNaiveBulkComplete: """ Enables a Task to be efficiently scheduled with e.g. 
range tools, by providing a bulk_complete implementation which checks completeness in a loop. Applicable to tasks whose completeness checking is cheap. This doesn't exploit output location specific APIs for speed advantage, nevertheless removes redundant scheduler roundtrips. """ @classmethod def bulk_complete(cls, parameter_tuples): generated_tuples = [] for parameter_tuple in parameter_tuples: if isinstance(parameter_tuple, (list, tuple)): if cls(*parameter_tuple).complete(): generated_tuples.append(parameter_tuple) elif isinstance(parameter_tuple, dict): if cls(**parameter_tuple).complete(): generated_tuples.append(parameter_tuple) else: if cls(parameter_tuple).complete(): generated_tuples.append(parameter_tuple) return generated_tuples class DynamicRequirements(object): """ Wraps dynamic requirements yielded in tasks's run methods to control how completeness checks of (e.g.) large chunks of tasks are performed. Besides the wrapped *requirements*, instances of this class can be passed an optional function *custom_complete* that might implement an optimized check for completeness. If set, the function will be called with a single argument, *complete_fn*, which should be used to perform the per-task check. Example: .. code-block:: python class SomeTaskWithDynamicRequirements(luigi.Task): ... def run(self): large_chunk_of_tasks = [OtherTask(i=i) for i in range(10000)] def custom_complete(complete_fn): # example: assume OtherTask always write into the same directory, so just check # if the first task is complete, and compare basenames for the rest if not complete_fn(large_chunk_of_tasks[0]): return False paths = [task.output().path for task in large_chunk_of_tasks] basenames = os.listdir(os.path.dirname(paths[0])) # a single fs call return all(os.path.basename(path) in basenames for path in paths) yield DynamicRequirements(large_chunk_of_tasks, custom_complete) .. py:attribute:: requirements The original, wrapped requirements. .. 
py:attribute:: custom_complete The optional, custom function performing the completeness check of the wrapped requirements. """ def __init__(self, requirements, custom_complete=None): super().__init__() # store attributes self.requirements = requirements self.custom_complete = custom_complete # cached flat requirements and paths self._flat_requirements = None self._paths = None @property def flat_requirements(self): if self._flat_requirements is None: self._flat_requirements = flatten(self.requirements) return self._flat_requirements @property def paths(self): if self._paths is None: self._paths = getpaths(self.requirements) return self._paths def complete(self, complete_fn=None): # default completeness check if complete_fn is None: def complete_fn(task): return task.complete() # use the custom complete function when set if self.custom_complete: return self.custom_complete(complete_fn) # default implementation return all(complete_fn(t) for t in self.flat_requirements) class ExternalTask(Task): """ Subclass for references to external dependencies. An ExternalTask's does not have a `run` implementation, which signifies to the framework that this Task's :py:meth:`output` is generated outside of Luigi. """ run = None def externalize(taskclass_or_taskobject): """ Returns an externalized version of a Task. You may both pass an instantiated task object or a task class. Some examples: .. code-block:: python class RequiringTask(luigi.Task): def requires(self): task_object = self.clone(MyTask) return externalize(task_object) ... Here's mostly equivalent code, but ``externalize`` is applied to a task class instead. .. code-block:: python @luigi.util.requires(externalize(MyTask)) class RequiringTask(luigi.Task): pass ... Of course, it may also be used directly on classes and objects (for example for reexporting or other usage). .. 
code-block:: python MyTask = externalize(MyTask) my_task_2 = externalize(MyTask2(param='foo')) If you however want a task class to be external from the beginning, you're better off inheriting :py:class:`ExternalTask` rather than :py:class:`Task`. This function tries to be side-effect free by creating a copy of the class or the object passed in and then modify that object. In particular this code shouldn't do anything. .. code-block:: python externalize(MyTask) # BAD: This does nothing (as after luigi 2.4.0) """ copied_value = copy.copy(taskclass_or_taskobject) if copied_value is taskclass_or_taskobject: # Assume it's a class clazz = taskclass_or_taskobject @_task_wraps(clazz) class _CopyOfClass(clazz): # How to copy a class: http://stackoverflow.com/a/9541120/621449 _visible_in_registry = False _CopyOfClass.run = None return _CopyOfClass else: # We assume it's an object copied_value.run = None return copied_value class WrapperTask(Task): """ Use for tasks that only wrap other tasks and that by definition are done if all their requirements exist. """ def complete(self): return all(r.complete() for r in flatten(self.requires())) class Config(Task): """ Class for configuration. See :ref:`ConfigClasses`. """ # TODO: let's refactor Task & Config so that it inherits from a common # ParamContainer base class pass def getpaths(struct): """ Maps all Tasks in a structured data object to their .output(). """ if isinstance(struct, Task): return struct.output() elif isinstance(struct, dict): return struct.__class__((k, getpaths(v)) for k, v in struct.items()) elif isinstance(struct, (list, tuple)): return struct.__class__(getpaths(r) for r in struct) else: # Remaining case: assume struct is iterable... try: return [getpaths(r) for r in struct] except TypeError: raise Exception("Cannot map %s to Task/dict/list" % str(struct)) def flatten(struct): """ Creates a flat list of all items in structured output (dicts, lists, items): .. 
code-block:: python >>> sorted(flatten({'a': 'foo', 'b': 'bar'})) ['bar', 'foo'] >>> sorted(flatten(['foo', ['bar', 'troll']])) ['bar', 'foo', 'troll'] >>> flatten('foo') ['foo'] >>> flatten(42) [42] """ if struct is None: return [] flat = [] if isinstance(struct, dict): for _, result in struct.items(): flat += flatten(result) return flat if isinstance(struct, str): return [struct] try: # if iterable iterator = iter(struct) except TypeError: return [struct] for result in iterator: flat += flatten(result) return flat def flatten_output(task): """ Lists all output targets by recursively walking output-less (wrapper) tasks. """ output_tasks = OrderedDict() # OrderedDict used as ordered set tasks_to_process = deque([task]) while tasks_to_process: current_task = tasks_to_process.popleft() if flatten(current_task.output()): if current_task not in output_tasks: output_tasks[current_task] = None else: tasks_to_process.extend(flatten(current_task.requires())) return flatten(task.output() for task in output_tasks) def _task_wraps(task_class): # In order to make the behavior of a wrapper class nicer, we set the name of the # new class to the wrapped class, and copy over the docstring and module as well. # This makes it possible to pickle the wrapped class etc. # Btw, this is a slight abuse of functools.wraps. It's meant to be used only for # functions, but it works for classes too, if you pass updated=[] assigned = functools.WRAPPER_ASSIGNMENTS + ("_namespace_at_class_time",) return functools.wraps(task_class, assigned=assigned, updated=[]) ================================================ FILE: luigi/task_history.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Abstract class for task history. Currently the only subclass is :py:class:`~luigi.db_task_history.DbTaskHistory`. """ import abc import logging logger = logging.getLogger("luigi-interface") class StoredTask: """ Interface for methods on TaskHistory """ # TODO : do we need this task as distinct from luigi.scheduler.Task? # this only records host and record_id in addition to task parameters. def __init__(self, task, status, host=None): self._task = task self.status = status self.record_id = None self.host = host @property def task_family(self): return self._task.family @property def parameters(self): return self._task.params class TaskHistory(metaclass=abc.ABCMeta): """ Abstract Base Class for updating the run history of a task """ @abc.abstractmethod def task_scheduled(self, task): pass @abc.abstractmethod def task_finished(self, task, successful): pass @abc.abstractmethod def task_started(self, task, worker_host): pass # TODO(erikbern): should web method (find_latest_runs etc) be abstract? class NopHistory(TaskHistory): def task_scheduled(self, task): pass def task_finished(self, task, successful): pass def task_started(self, task, worker_host): pass ================================================ FILE: luigi/task_register.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Define the centralized register of all :class:`~luigi.task.Task` classes. """ import abc import logging from typing import Any, Dict, List logger = logging.getLogger("luigi-interface") class TaskClassException(Exception): pass class TaskClassNotFoundException(TaskClassException): pass class TaskClassAmbigiousException(TaskClassException): pass class Register(abc.ABCMeta): """ The Metaclass of :py:class:`Task`. Acts as a global registry of Tasks with the following properties: 1. Cache instances of objects so that eg. ``X(1, 2, 3)`` always returns the same object. 2. Keep track of all subclasses of :py:class:`Task` and expose them. """ __instance_cache: Dict[str, Any] = {} _default_namespace_dict: Dict[str, Any] = {} _reg: List[Any] = [] AMBIGUOUS_CLASS = object() # Placeholder denoting an error """If this value is returned by :py:meth:`_get_reg` then there is an ambiguous task name (two :py:class:`Task` have the same name). This denotes an error.""" def __new__(metacls, classname, bases, classdict, **kwargs): """ Custom class creation for namespacing. Also register all subclasses. When the set or inherited namespace evaluates to ``None``, set the task namespace to whatever the currently declared namespace is. """ cls = super(Register, metacls).__new__(metacls, classname, bases, classdict, **kwargs) cls._namespace_at_class_time = metacls._get_namespace(cls.__module__) metacls._reg.append(cls) return cls def __call__(cls, *args, **kwargs): """ Custom class instantiation utilizing instance cache. 
If a Task has already been instantiated with the same parameters, the previous instance is returned to reduce number of object instances. """ def instantiate(): return super(Register, cls).__call__(*args, **kwargs) h = cls.__instance_cache if h is None: # disabled return instantiate() params = cls.get_params() param_values = cls.get_param_values(params, args, kwargs) k = (cls, tuple(param_values)) try: hash(k) except TypeError: logger.debug("Not all parameter values are hashable so instance isn't coming from the cache") return instantiate() # unhashable types in parameters if k not in h: h[k] = instantiate() return h[k] @classmethod def clear_instance_cache(cls): """ Clear/Reset the instance cache. """ cls.__instance_cache = {} @classmethod def disable_instance_cache(cls): """ Disables the instance cache. """ cls.__instance_cache = None @property def task_family(cls): """ Internal note: This function will be deleted soon. """ task_namespace = cls.get_task_namespace() if not task_namespace: return cls.__name__ else: return f"{task_namespace}.{cls.__name__}" @classmethod def _get_reg(cls): """Return all of the registered classes. :return: an ``dict`` of task_family -> class """ # We have to do this on-demand in case task names have changed later reg = dict() for task_cls in cls._reg: if not task_cls._visible_in_registry: continue name = task_cls.get_task_family() if name in reg and ( reg[name] == Register.AMBIGUOUS_CLASS # Check so issubclass doesn't crash or not issubclass(task_cls, reg[name]) ): # Registering two different classes - this means we can't instantiate them by name # The only exception is if one class is a subclass of the other. In that case, we # instantiate the most-derived class (this fixes some issues with decorator wrappers). 
reg[name] = Register.AMBIGUOUS_CLASS else: reg[name] = task_cls return reg @classmethod def _set_reg(cls, reg): """The writing complement of _get_reg""" cls._reg = [task_cls for task_cls in reg.values() if task_cls is not cls.AMBIGUOUS_CLASS] @classmethod def task_names(cls): """ List of task names as strings """ return sorted(cls._get_reg().keys()) @classmethod def tasks_str(cls): """ Human-readable register contents dump. """ return ",".join(cls.task_names()) @classmethod def get_task_cls(cls, name): """ Returns an unambiguous class or raises an exception. """ task_cls = cls._get_reg().get(name) if not task_cls: raise TaskClassNotFoundException(cls._missing_task_msg(name)) if task_cls == cls.AMBIGUOUS_CLASS: raise TaskClassAmbigiousException("Task %r is ambiguous" % name) return task_cls @classmethod def get_all_params(cls): """ Compiles and returns all parameters for all :py:class:`Task`. :return: a generator of tuples (TODO: we should make this more elegant) """ for task_name, task_cls in cls._get_reg().items(): if task_cls == cls.AMBIGUOUS_CLASS: continue for param_name, param_obj in task_cls.get_params(): yield task_name, (not task_cls.use_cmdline_section), param_name, param_obj @staticmethod def _editdistance(a, b): """Simple unweighted Levenshtein distance""" r0 = range(0, len(b) + 1) r1 = [0] * (len(b) + 1) for i in range(0, len(a)): r1[0] = i + 1 for j in range(0, len(b)): c = 0 if a[i] is b[j] else 1 r1[j + 1] = min(r1[j] + 1, r0[j + 1] + 1, r0[j] + c) r0 = r1[:] return r1[len(b)] @classmethod def _missing_task_msg(cls, task_name): weighted_tasks = [(Register._editdistance(task_name, task_name_2), task_name_2) for task_name_2 in cls.task_names()] ordered_tasks = sorted(weighted_tasks, key=lambda pair: pair[0]) candidates = [task for (dist, task) in ordered_tasks if dist <= 5 and dist < len(task)] if candidates: return "No task %s. Did you mean:\n%s" % (task_name, "\n".join(candidates)) else: return "No task %s. 
Candidates are: %s" % (task_name, cls.tasks_str()) @classmethod def _get_namespace(mcs, module_name): for parent in mcs._module_parents(module_name): entry = mcs._default_namespace_dict.get(parent) if entry: return entry return "" # Default if nothing specifies @staticmethod def _module_parents(module_name): """ >>> list(Register._module_parents('a.b')) ['a.b', 'a', ''] """ spl = module_name.split(".") for i in range(len(spl), 0, -1): yield ".".join(spl[0:i]) if module_name: yield "" def load_task(module, task_name, params_str): """ Imports task dynamically given a module and a task name. """ if module is not None: __import__(module) task_cls = Register.get_task_cls(task_name) return task_cls.from_str_params(params_str) ================================================ FILE: luigi/task_status.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Possible values for a Task's status in the Scheduler """ PENDING = "PENDING" FAILED = "FAILED" DONE = "DONE" RUNNING = "RUNNING" BATCH_RUNNING = "BATCH_RUNNING" SUSPENDED = "SUSPENDED" # Only kept for backward compatibility with old clients UNKNOWN = "UNKNOWN" DISABLED = "DISABLED" ================================================ FILE: luigi/templates/history.html ================================================ {% extends "layout.html" %} {% block content %}

    {{name}} History


    {% if statusResults and taskResults %} {% end %}
    {% end %} ================================================ FILE: luigi/templates/layout.html ================================================ Luigi History Viewer
    {% block content %}{% end %}
    ================================================ FILE: luigi/templates/menu.html ================================================ {% extends "layout.html" %} {% block content %}
    {% if tasknames %}

    [ Task History ]

    {% end %}
    {% end %} ================================================ FILE: luigi/templates/recent.html ================================================ {% extends "layout.html" %} {% block content %}

    Luigi Task History

    {% for task in tasks %} {% end %}
    Name Host Last Action Status Parameters
    {{task.name}} {{task.host}} {{task.events[0].ts}} {{task.events[0].event_name}} {% for (k, param) in task.parameters.items() %}
    {{k}}{{param.value}}
    {% end %}
    {% end %} ================================================ FILE: luigi/templates/show.html ================================================ {% extends "layout.html" %} {% block content %}

    Info

    Task Id {{task.id}}
    Task Name {{task.name}}
    Host {{task.host}}
    More All "{{task.name}}" runs

    Parameters

    {% for (k, param) in task.parameters.items() %} {% end %}
    Name Value
    {{k}} {{param.value}}

    Actions

    {% for event in task.events %} {% end %}
    Status Action Time
    {{event.event_name}} {{event.ts}}
    {% end %} ================================================ FILE: luigi/tools/__init__.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) 2014 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of # the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations under # the License. """ Sort of a standard library for doing stuff with Tasks at a somewhat abstract level. Submodule introduced to stop growing util.py unstructured. """ ================================================ FILE: luigi/tools/deps.py ================================================ #!/usr/bin/env python # Finds all tasks and task outputs on the dependency paths from the given downstream task T # up to the given source/upstream task S (optional). If the upstream task is not given, # all upstream tasks on all dependency paths of T will be returned. # Terms: # if the execution of Task T depends on the output of task S on a dependency graph, # T is called a downstream/sink task, S is called an upstream/source task. # This is a useful and practical way to find all upstream tasks of task T. # For example suppose you have a daily computation that starts with a task named Daily. # And suppose you have another task named Aggregate. Daily triggers a few tasks # which eventually trigger Aggregate. Now, suppose you find a bug in Aggregate. # You fixed the bug and now you want to rerun it, including all its upstream deps.
# # To do that you run: # bin/deps.py --module daily_module Aggregate --daily-param1 xxx --upstream-family Daily # # This will output all the tasks on the dependency path between Daily and Aggregate. In # effect, this is how you find all upstream tasks for Aggregate. Now you can delete its # output and run Aggregate again. Daily will eventually trigget Aggregate and all tasks on # the way. # # The same code here might be used as a CLI tool as well as a python module. # In python, invoke find_deps(task, upstream_name) to get a set of all task instances on the # paths between task T and upstream task S. You can then use the task instances to delete their output or # perform other computation based on that. # # Example: # # PYTHONPATH=$PYTHONPATH:/path/to/your/luigi/tasks bin/deps.py \ # --module my.tasks MyDownstreamTask # --downstream_task_param1 123456 # [--upstream-family MyUpstreamTask] # import sys from collections.abc import Iterable import luigi.interface from luigi import parameter from luigi.cmdline_parser import CmdlineParser from luigi.contrib.postgres import PostgresTarget from luigi.contrib.s3 import S3Target from luigi.contrib.ssh import RemoteTarget from luigi.target import FileSystemTarget from luigi.task import flatten def get_task_requires(task): return set(flatten(task.requires())) def dfs_paths(start_task, goal_task_family, path=None): if path is None: path = [start_task] if start_task.task_family == goal_task_family or goal_task_family is None: for item in path: yield item for next in get_task_requires(start_task) - set(path): for t in dfs_paths(next, goal_task_family, path + [next]): yield t class upstream(luigi.task.Config): """ Used to provide the parameter upstream-family """ family = parameter.OptionalParameter(default=None) def find_deps(task, upstream_task_family): """ Finds all dependencies that start with the given task and have a path to upstream_task_family Returns all deps on all paths between task and upstream """ return {t for t in 
dfs_paths(task, upstream_task_family)} def find_deps_cli(): """ Finds all tasks on all paths from provided CLI task """ cmdline_args = sys.argv[1:] with CmdlineParser.global_instance(cmdline_args) as cp: return find_deps(cp.get_task_obj(), upstream().family) def get_task_output_description(task_output): """ Returns a task's output as a string """ output_description = "n/a" if isinstance(task_output, RemoteTarget): output_description = "[SSH] {0}:{1}".format(task_output._fs.remote_context.host, task_output.path) elif isinstance(task_output, S3Target): output_description = "[S3] {0}".format(task_output.path) elif isinstance(task_output, FileSystemTarget): output_description = "[FileSystem] {0}".format(task_output.path) elif isinstance(task_output, PostgresTarget): output_description = "[DB] {0}:{1}".format(task_output.host, task_output.table) else: output_description = "to be determined" return output_description def main(): deps = find_deps_cli() for task in deps: task_output = task.output() if isinstance(task_output, dict): output_descriptions = [get_task_output_description(output) for label, output in task_output.items()] elif isinstance(task_output, Iterable): output_descriptions = [get_task_output_description(output) for output in task_output] else: output_descriptions = [get_task_output_description(task_output)] print(" TASK: {0}".format(task)) for desc in output_descriptions: print(" : {0}".format(desc)) if __name__ == "__main__": main() ================================================ FILE: luigi/tools/deps_tree.py ================================================ # -*- coding: utf-8 -*- """ This module parses commands exactly the same as the luigi task runner. You must specify the module, the task and task parameters. Instead of executing a task, this module prints the significant parameters and state of the task and its dependencies in a tree format. Use this to visualize the execution plan in the terminal. .. 
code-block:: none $ luigi-deps-tree --module foo_complex examples.Foo ... └─--[Foo-{} (PENDING)] |---[Bar-{'num': '0'} (PENDING)] | |---[Bar-{'num': '4'} (PENDING)] | └─--[Bar-{'num': '5'} (PENDING)] |---[Bar-{'num': '1'} (PENDING)] └─--[Bar-{'num': '2'} (PENDING)] └─--[Bar-{'num': '6'} (PENDING)] |---[Bar-{'num': '7'} (PENDING)] | |---[Bar-{'num': '9'} (PENDING)] | └─--[Bar-{'num': '10'} (PENDING)] | └─--[Bar-{'num': '11'} (PENDING)] └─--[Bar-{'num': '8'} (PENDING)] └─--[Bar-{'num': '12'} (PENDING)] """ import sys import warnings from luigi.cmdline_parser import CmdlineParser from luigi.task import flatten class bcolors: """ colored output for task status """ OKBLUE = "\033[94m" OKGREEN = "\033[92m" ENDC = "\033[0m" def print_tree(task, indent="", last=True): """ Return a string representation of the tasks, their statuses/parameters in a dependency tree format """ # dont bother printing out warnings about tasks with no output with warnings.catch_warnings(): warnings.filterwarnings(action="ignore", message="Task .* without outputs has no custom complete\\(\\) method") is_task_complete = task.complete() is_complete = (bcolors.OKGREEN + "COMPLETE" if is_task_complete else bcolors.OKBLUE + "PENDING") + bcolors.ENDC name = task.__class__.__name__ params = task.to_str_params(only_significant=True) result = "\n" + indent if last: result += "└─--" indent += " " else: result += "|---" indent += "| " result += "[{0}-{1} ({2})]".format(name, params, is_complete) children = flatten(task.requires()) for index, child in enumerate(children): result += print_tree(child, indent, (index + 1) == len(children)) return result def main(): cmdline_args = sys.argv[1:] with CmdlineParser.global_instance(cmdline_args) as cp: task = cp.get_task_obj() print(print_tree(task)) if __name__ == "__main__": main() ================================================ FILE: luigi/tools/luigi_grep.py ================================================ #!/usr/bin/env python import argparse import json from 
collections import defaultdict from urllib.request import urlopen class LuigiGrep: def __init__(self, host, port): self._host = host self._port = port @property def graph_url(self): return "http://{0}:{1}/api/graph".format(self._host, self._port) def _fetch_json(self): """Returns the json representation of the dep graph""" print("Fetching from url: " + self.graph_url) resp = urlopen(self.graph_url).read() return json.loads(resp.decode("utf-8")) def _build_results(self, jobs, job): job_info = jobs[job] deps = job_info["deps"] deps_status = defaultdict(list) for j in deps: if j in jobs: deps_status[jobs[j]["status"]].append(j) else: deps_status["UNKNOWN"].append(j) return {"name": job, "status": job_info["status"], "deps_by_status": deps_status} def prefix_search(self, job_name_prefix): """Searches for jobs matching the given ``job_name_prefix``.""" json = self._fetch_json() jobs = json["response"] for job in jobs: if job.startswith(job_name_prefix): yield self._build_results(jobs, job) def status_search(self, status): """Searches for jobs matching the given ``status``.""" json = self._fetch_json() jobs = json["response"] for job in jobs: job_info = jobs[job] if job_info["status"].lower() == status.lower(): yield self._build_results(jobs, job) def main(): parser = argparse.ArgumentParser("luigi-grep is used to search for workflows using the luigi scheduler's json api") parser.add_argument("--scheduler-host", default="localhost", help="hostname of the luigi scheduler") parser.add_argument("--scheduler-port", default="8082", help="port of the luigi scheduler") parser.add_argument("--prefix", help="prefix of a task query to search for", default=None) parser.add_argument("--status", help="search for jobs with the given status", default=None) args = parser.parse_args() grep = LuigiGrep(args.scheduler_host, args.scheduler_port) results = [] if args.prefix: results = grep.prefix_search(args.prefix) elif args.status: results = grep.status_search(args.status) for job in 
results: print("{name}: {status}, Dependencies:".format(name=job["name"], status=job["status"])) for status, jobs in job["deps_by_status"].items(): print(" status={status}".format(status=status)) for job in jobs: print(" {job}".format(job=job)) if __name__ == "__main__": main() ================================================ FILE: luigi/tools/range.py ================================================ # -*- coding: utf-8 -*- # Copyright (c) 2014 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of # the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations under # the License. """ Produces contiguous completed ranges of recurring tasks. See ``RangeDaily`` and ``RangeHourly`` for basic usage. Caveat - if gaps accumulate, their causes (e.g. missing dependencies) going unmonitored/unmitigated, then this will eventually keep retrying the same gaps over and over and make no progress to more recent times. (See ``task_limit`` and ``reverse`` parameters.) TODO foolproof against that kind of misuse? """ import functools import itertools import logging import re import time import warnings from collections import Counter from datetime import date, datetime, timedelta from dateutil.relativedelta import relativedelta import luigi from luigi.parameter import ParameterException from luigi.target import FileSystemTarget from luigi.task import Register, flatten_output logger = logging.getLogger("luigi-interface") class RangeEvent(luigi.Event): # Not sure if subclassing currently serves a purpose. Stringly typed, events are. """ Events communicating useful metrics. 
``COMPLETE_COUNT`` would normally be nondecreasing, and its derivative would describe performance (how many instances complete invocation-over-invocation). ``COMPLETE_FRACTION`` reaching 1 would be a telling event in case of a backfill with defined start and stop. Would not be strikingly useful for a typical recurring task without stop defined, fluctuating close to 1. ``DELAY`` is measured from the first found missing datehour till (current time + hours_forward), or till stop if it is defined. In hours for Hourly. TBD different units for other frequencies? TODO any different for reverse mode? From first missing till last missing? From last gap till stop? """ COMPLETE_COUNT = "event.tools.range.complete.count" COMPLETE_FRACTION = "event.tools.range.complete.fraction" DELAY = "event.tools.range.delay" class RangeBase(luigi.WrapperTask): """ Produces a contiguous completed range of a recurring task. Made for the common use case where a task is parameterized by e.g. ``DateParameter``, and assurance is needed that any gaps arising from downtime are eventually filled. Emits events that one can use to monitor gaps and delays. At least one of start and stop needs to be specified. (This is quite an abstract base class for subclasses with different datetime parameter classes, e.g. ``DateParameter``, ``DateHourParameter``, ..., and different parameter naming, e.g. days_back/forward, hours_back/forward, ..., as well as different documentation wording, to improve user experience.) Subclasses will need to use the ``of`` parameter when overriding methods. """ # TODO lift the single parameter constraint by passing unknown parameters through WrapperTask? of = luigi.TaskParameter(description="task name to be completed. The task must take a single datetime parameter") of_params = luigi.DictParameter(default=dict(), description="Arguments to be provided to the 'of' class when instantiating") # The common parameters 'start' and 'stop' have type (e.g. 
DateParameter, # DateHourParameter) dependent on the concrete subclass, cumbersome to # define here generically without dark magic. Refer to the overrides. start = luigi.Parameter() stop = luigi.Parameter() reverse = luigi.BoolParameter( default=False, description="specifies the preferred order for catching up. False - work from the oldest missing outputs onward; True - from the newest backward", ) task_limit = luigi.IntParameter(default=50, description="how many of 'of' tasks to require. Guards against scheduling insane amounts of tasks in one go") # TODO overridable exclude_datetimes or something... now = luigi.IntParameter(default=None, description="set to override current time. In seconds since epoch") param_name = luigi.Parameter( default=None, description="parameter name used to pass in parameterized value. Defaults to None, meaning use first positional parameter", positional=False, ) @property def of_cls(self): """ DONT USE. Will be deleted soon. Use ``self.of``! """ if isinstance(self.of, str): warnings.warn('When using Range programatically, dont pass "of" param as string!') return Register.get_task_cls(self.of) return self.of # a bunch of datetime arithmetic building blocks that need to be provided in subclasses def datetime_to_parameter(self, dt): raise NotImplementedError def parameter_to_datetime(self, p): raise NotImplementedError def datetime_to_parameters(self, dt): """ Given a date-time, will produce a dictionary of of-params combined with the ranged task parameter """ raise NotImplementedError def parameters_to_datetime(self, p): """ Given a dictionary of parameters, will extract the ranged task parameter value """ raise NotImplementedError def moving_start(self, now): """ Returns a datetime from which to ensure contiguousness in the case when start is None or unfeasibly far back. 
""" raise NotImplementedError def moving_stop(self, now): """ Returns a datetime till which to ensure contiguousness in the case when stop is None or unfeasibly far forward. """ raise NotImplementedError def finite_datetimes(self, finite_start, finite_stop): """ Returns the individual datetimes in interval [finite_start, finite_stop) for which task completeness should be required, as a sorted list. """ raise NotImplementedError def _emit_metrics(self, missing_datetimes, finite_start, finite_stop): """ For consistent metrics one should consider the entire range, but it is open (infinite) if stop or start is None. Hence make do with metrics respective to the finite simplification. """ datetimes = self.finite_datetimes( finite_start if self.start is None else min(finite_start, self.parameter_to_datetime(self.start)), finite_stop if self.stop is None else max(finite_stop, self.parameter_to_datetime(self.stop)), ) delay_in_jobs = len(datetimes) - datetimes.index(missing_datetimes[0]) if datetimes and missing_datetimes else 0 self.trigger_event(RangeEvent.DELAY, self.of_cls.task_family, delay_in_jobs) expected_count = len(datetimes) complete_count = expected_count - len(missing_datetimes) self.trigger_event(RangeEvent.COMPLETE_COUNT, self.of_cls.task_family, complete_count) self.trigger_event(RangeEvent.COMPLETE_FRACTION, self.of_cls.task_family, float(complete_count) / expected_count if expected_count else 1) def _format_datetime(self, dt): return self.datetime_to_parameter(dt) def _format_range(self, datetimes): param_first = self._format_datetime(datetimes[0]) param_last = self._format_datetime(datetimes[-1]) return "[%s, %s]" % (param_first, param_last) def _instantiate_task_cls(self, param): return self.of(**self._task_parameters(param)) @property def _param_name(self): if self.param_name is None: return next(x[0] for x in self.of.get_params() if x[1].positional) else: return self.param_name def _task_parameters(self, param): kwargs = dict(**self.of_params) 
kwargs[self._param_name] = param return kwargs def requires(self): # cache because we anticipate a fair amount of computation if hasattr(self, "_cached_requires"): return self._cached_requires if not self.start and not self.stop: raise ParameterException("At least one of start and stop needs to be specified") if not self.start and not self.reverse: raise ParameterException("Either start needs to be specified or reverse needs to be True") if self.start and self.stop and self.start > self.stop: raise ParameterException("Can't have start > stop") # TODO check overridden complete() and exists() now = datetime.utcfromtimestamp(time.time() if self.now is None else self.now) moving_start = self.moving_start(now) finite_start = moving_start if self.start is None else max(self.parameter_to_datetime(self.start), moving_start) moving_stop = self.moving_stop(now) finite_stop = moving_stop if self.stop is None else min(self.parameter_to_datetime(self.stop), moving_stop) datetimes = self.finite_datetimes(finite_start, finite_stop) if finite_start <= finite_stop else [] if datetimes: logger.debug("Actually checking if range %s of %s is complete", self._format_range(datetimes), self.of_cls.task_family) missing_datetimes = sorted(self._missing_datetimes(datetimes)) logger.debug( "Range %s lacked %d of expected %d %s instances", self._format_range(datetimes), len(missing_datetimes), len(datetimes), self.of_cls.task_family ) else: missing_datetimes = [] logger.debug("Empty range. 
No %s instances expected", self.of_cls.task_family) self._emit_metrics(missing_datetimes, finite_start, finite_stop) if self.reverse: required_datetimes = missing_datetimes[-self.task_limit :] else: required_datetimes = missing_datetimes[: self.task_limit] if required_datetimes: logger.debug( "Requiring %d missing %s instances in range %s", len(required_datetimes), self.of_cls.task_family, self._format_range(required_datetimes) ) if self.reverse: required_datetimes.reverse() # TODO priorities, so that within the batch tasks are ordered too self._cached_requires = [self._instantiate_task_cls(self.datetime_to_parameter(d)) for d in required_datetimes] return self._cached_requires def missing_datetimes(self, finite_datetimes): """ Override in subclasses to do bulk checks. Returns a sorted list. This is a conservative base implementation that brutally checks completeness, instance by instance. Inadvisable as it may be slow. """ return [d for d in finite_datetimes if not self._instantiate_task_cls(self.datetime_to_parameter(d)).complete()] def _missing_datetimes(self, finite_datetimes): """ Backward compatible wrapper. Will be deleted eventually (stated on Dec 2015) """ try: return self.missing_datetimes(finite_datetimes) except TypeError as ex: if "missing_datetimes()" in repr(ex): warnings.warn("In your Range* subclass, missing_datetimes() should only take 1 argument (see latest docs)") return self.missing_datetimes(self.of_cls, finite_datetimes) else: raise class RangeDailyBase(RangeBase): """ Produces a contiguous completed range of a daily recurring task. """ start = luigi.DateParameter(default=None, description="beginning date, inclusive. Default: None - work backward forever (requires reverse=True)") stop = luigi.DateParameter(default=None, description="ending date, exclusive. 
Default: None - work forward forever") days_back = luigi.IntParameter( default=100, # slightly more than three months description=( "extent to which contiguousness is to be assured into " "past, in days from current time. Prevents infinite loop " "when start is none. If the dataset has limited retention" " (i.e. old outputs get removed), this should be set " "shorter to that, too, to prevent the oldest outputs " "flapping. Increase freely if you intend to process old " "dates - worker's memory is the limit" ), ) days_forward = luigi.IntParameter( default=0, description="extent to which contiguousness is to be assured into future, in days from current time. Prevents infinite loop when stop is none", ) def datetime_to_parameter(self, dt): return dt.date() def parameter_to_datetime(self, p): return datetime(p.year, p.month, p.day) def datetime_to_parameters(self, dt): """ Given a date-time, will produce a dictionary of of-params combined with the ranged task parameter """ return self._task_parameters(dt.date()) def parameters_to_datetime(self, p): """ Given a dictionary of parameters, will extract the ranged task parameter value """ dt = p[self._param_name] return datetime(dt.year, dt.month, dt.day) def moving_start(self, now): return now - timedelta(days=self.days_back) def moving_stop(self, now): return now + timedelta(days=self.days_forward) def finite_datetimes(self, finite_start, finite_stop): """ Simply returns the points in time that correspond to turn of day. """ date_start = datetime(finite_start.year, finite_start.month, finite_start.day) dates = [] for i in itertools.count(): t = date_start + timedelta(days=i) if t >= finite_stop: return dates if t >= finite_start: dates.append(t) class RangeHourlyBase(RangeBase): """ Produces a contiguous completed range of an hourly recurring task. """ start = luigi.DateHourParameter(default=None, description="beginning datehour, inclusive. 
Default: None - work backward forever (requires reverse=True)") stop = luigi.DateHourParameter(default=None, description="ending datehour, exclusive. Default: None - work forward forever") hours_back = luigi.IntParameter( default=100 * 24, # slightly more than three months description=( "extent to which contiguousness is to be assured into " "past, in hours from current time. Prevents infinite " "loop when start is none. If the dataset has limited " "retention (i.e. old outputs get removed), this should " "be set shorter to that, too, to prevent the oldest " "outputs flapping. Increase freely if you intend to " "process old dates - worker's memory is the limit" ), ) # TODO always entire interval for reprocessings (fixed start and stop)? hours_forward = luigi.IntParameter( default=0, description="extent to which contiguousness is to be assured into future, in hours from current time. Prevents infinite loop when stop is none", ) def datetime_to_parameter(self, dt): return dt def parameter_to_datetime(self, p): return p def datetime_to_parameters(self, dt): """ Given a date-time, will produce a dictionary of of-params combined with the ranged task parameter """ return self._task_parameters(dt) def parameters_to_datetime(self, p): """ Given a dictionary of parameters, will extract the ranged task parameter value """ return p[self._param_name] def moving_start(self, now): return now - timedelta(hours=self.hours_back) def moving_stop(self, now): return now + timedelta(hours=self.hours_forward) def finite_datetimes(self, finite_start, finite_stop): """ Simply returns the points in time that correspond to whole hours. 
""" datehour_start = datetime(finite_start.year, finite_start.month, finite_start.day, finite_start.hour) datehours = [] for i in itertools.count(): t = datehour_start + timedelta(hours=i) if t >= finite_stop: return datehours if t >= finite_start: datehours.append(t) def _format_datetime(self, dt): return luigi.DateHourParameter().serialize(dt) class RangeByMinutesBase(RangeBase): """ Produces a contiguous completed range of an recurring tasks separated a specified number of minutes. """ start = luigi.DateMinuteParameter( default=None, description="beginning date-hour-minute, inclusive. Default: None - work backward forever (requires reverse=True)" ) stop = luigi.DateMinuteParameter(default=None, description="ending date-hour-minute, exclusive. Default: None - work forward forever") minutes_back = luigi.IntParameter( default=60 * 24, # one day description=( "extent to which contiguousness is to be assured into " "past, in minutes from current time. Prevents infinite " "loop when start is none. If the dataset has limited " "retention (i.e. old outputs get removed), this should " "be set shorter to that, too, to prevent the oldest " "outputs flapping. Increase freely if you intend to " "process old dates - worker's memory is the limit" ), ) minutes_forward = luigi.IntParameter( default=0, description="extent to which contiguousness is to be assured into future, in minutes from current time. Prevents infinite loop when stop is none", ) minutes_interval = luigi.IntParameter(default=1, description="separation between events in minutes. 
It must evenly divide 60") def datetime_to_parameter(self, dt): return dt def parameter_to_datetime(self, p): return p def datetime_to_parameters(self, dt): """ Given a date-time, will produce a dictionary of of-params combined with the ranged task parameter """ return self._task_parameters(dt) def parameters_to_datetime(self, p): """ Given a dictionary of parameters, will extract the ranged task parameter value """ dt = p[self._param_name] return datetime(dt.year, dt.month, dt.day, dt.hour, dt.minute) def moving_start(self, now): return now - timedelta(minutes=self.minutes_back) def moving_stop(self, now): return now + timedelta(minutes=self.minutes_forward) def finite_datetimes(self, finite_start, finite_stop): """ Simply returns the points in time that correspond to a whole number of minutes intervals. """ # Validate that the minutes_interval can divide 60 and it is greater than 0 and lesser than 60 if not (0 < self.minutes_interval < 60): raise ParameterException("minutes-interval must be within 0..60") if 60 % self.minutes_interval != 0: raise ParameterException("minutes-interval does not evenly divide 60") # start of a complete interval, e.g. 20:13 and the interval is 5 -> 20:10 start_minute = int(finite_start.minute / self.minutes_interval) * self.minutes_interval datehour_start = datetime(year=finite_start.year, month=finite_start.month, day=finite_start.day, hour=finite_start.hour, minute=start_minute) datehours = [] for i in itertools.count(): t = datehour_start + timedelta(minutes=i * self.minutes_interval) if t >= finite_stop: return datehours if t >= finite_start: datehours.append(t) def _format_datetime(self, dt): return luigi.DateMinuteParameter().serialize(dt) def _constrain_glob(glob, paths, limit=5): """ Tweaks glob into a list of more specific globs that together still cover paths and not too much extra. Saves us minutes long listings for long dataset histories. 
Specifically, in this implementation the leftmost occurrences of "[0-9]" give rise to a few separate globs that each specialize the expression to digits that actually occur in paths. """ def digit_set_wildcard(chars): """ Makes a wildcard expression for the set, a bit readable, e.g. [1-5]. """ chars = sorted(chars) if len(chars) > 1 and ord(chars[-1]) - ord(chars[0]) == len(chars) - 1: return "[%s-%s]" % (chars[0], chars[-1]) else: return "[%s]" % "".join(chars) current = {glob: paths} while True: pos = list(current.keys())[0].find("[0-9]") if pos == -1: # no wildcard expressions left to specialize in the glob return list(current.keys()) char_sets = {} for g, p in current.items(): char_sets[g] = sorted({path[pos] for path in p}) if sum(len(s) for s in char_sets.values()) > limit: return [g.replace("[0-9]", digit_set_wildcard(char_sets[g]), 1) for g in current] for g, s in char_sets.items(): for c in s: new_glob = g.replace("[0-9]", c, 1) new_paths = list(filter(lambda p: p[pos] == c, current[g])) current[new_glob] = new_paths del current[g] def most_common(items): [(element, counter)] = Counter(items).most_common(1) return element, counter def _get_per_location_glob(tasks, outputs, regexes): """ Builds a glob listing existing output paths. Esoteric reverse engineering, but worth it given that (compared to an equivalent contiguousness guarantee by naive complete() checks) requests to the filesystem are cut by orders of magnitude, and users don't even have to retrofit existing tasks anyhow. """ paths = [o.path for o in outputs] # naive, because some matches could be confused by numbers earlier # in path, e.g. 
/foo/fifa2000k/bar/2000-12-31/00 matches = [r.search(p) for r, p in zip(regexes, paths)] for m, p, t in zip(matches, paths, tasks): if m is None: raise NotImplementedError("Couldn't deduce datehour representation in output path %r of task %s" % (p, t)) n_groups = len(matches[0].groups()) # the most common position of every group is likely # to be conclusive hit or miss positions = [most_common((m.start(i), m.end(i)) for m in matches)[0] for i in range(1, n_groups + 1)] glob = list(paths[0]) # FIXME sanity check that it's the same for all paths for start, end in positions: glob = glob[:start] + ["[0-9]"] * (end - start) + glob[end:] # chop off the last path item # (wouldn't need to if `hadoop fs -ls -d` equivalent were available) return "".join(glob).rsplit("/", 1)[0] def _get_filesystems_and_globs(datetime_to_task, datetime_to_re): """ Yields a (filesystem, glob) tuple per every output location of task. The task can have one or several FileSystemTarget outputs. For convenience, the task can be a luigi.WrapperTask, in which case outputs of all its dependencies are considered. """ # probe some scattered datetimes unlikely to all occur in paths, other than by being sincere datetime parameter's representations # TODO limit to [self.start, self.stop) so messages are less confusing? Done trivially it can kill correctness sample_datetimes = [datetime(y, m, d, h) for y in range(2000, 2050, 10) for m in range(1, 4) for d in range(5, 8) for h in range(21, 24)] regexes = [re.compile(datetime_to_re(d)) for d in sample_datetimes] sample_tasks = [datetime_to_task(d) for d in sample_datetimes] sample_outputs = [flatten_output(t) for t in sample_tasks] for o, t in zip(sample_outputs, sample_tasks): if len(o) != len(sample_outputs[0]): raise NotImplementedError("Outputs must be consistent over time, sorry; was %r for %r and %r for %r" % (o, t, sample_outputs[0], sample_tasks[0])) # TODO fall back on requiring last couple of days? 
to avoid astonishing blocking when changes like that are deployed # erm, actually it's not hard to test entire hours_back..hours_forward and split into consistent subranges FIXME? for target in o: if not isinstance(target, FileSystemTarget): raise NotImplementedError("Output targets must be instances of FileSystemTarget; was %r for %r" % (target, t)) for o in zip(*sample_outputs): # transposed, so here we're iterating over logical outputs, not datetimes glob = _get_per_location_glob(sample_tasks, o, regexes) yield o[0].fs, glob def _list_existing(filesystem, glob, paths): """ Get all the paths that do in fact exist. Returns a set of all existing paths. Takes a luigi.target.FileSystem object, a str which represents a glob and a list of strings representing paths. """ globs = _constrain_glob(glob, paths) time_start = time.time() listing = [] for g in sorted(globs): logger.debug("Listing %s", g) if filesystem.exists(g): listing.extend(filesystem.listdir(g)) logger.debug("%d %s listings took %f s to return %d items", len(globs), filesystem.__class__.__name__, time.time() - time_start, len(listing)) return set(listing) def infer_bulk_complete_from_fs(datetimes, datetime_to_task, datetime_to_re): """ Efficiently determines missing datetimes by filesystem listing. The current implementation works for the common case of a task writing output to a ``FileSystemTarget`` whose path is built using strftime with format like '...%Y...%m...%d...%H...', without custom ``complete()`` or ``exists()``. (Eventually Luigi could have ranges of completion as first-class citizens. Then this listing business could be factored away/be provided for explicitly in target API or some kind of a history server.) 
""" filesystems_and_globs_by_location = _get_filesystems_and_globs(datetime_to_task, datetime_to_re) paths_by_datetime = [[o.path for o in flatten_output(datetime_to_task(d))] for d in datetimes] listing = set() for (f, g), p in zip(filesystems_and_globs_by_location, zip(*paths_by_datetime)): # transposed, so here we're iterating over logical outputs, not datetimes listing |= _list_existing(f, g, p) # quickly learn everything that's missing missing_datetimes = [] for d, p in zip(datetimes, paths_by_datetime): if not set(p) <= listing: missing_datetimes.append(d) return missing_datetimes class RangeMonthly(RangeBase): """ Produces a contiguous completed range of a monthly recurring task. Unlike the Range* classes with shorter intervals, this class does not perform bulk optimisation. It is assumed that the number of months is low enough not to motivate the increased complexity. Hence, there is no class RangeMonthlyBase. """ start = luigi.MonthParameter(default=None, description="beginning month, inclusive. Default: None - work backward forever (requires reverse=True)") stop = luigi.MonthParameter(default=None, description="ending month, exclusive. Default: None - work forward forever") months_back = luigi.IntParameter( default=13, # Little over a year description=( "extent to which contiguousness is to be assured into " "past, in months from current time. Prevents infinite loop " "when start is none. If the dataset has limited retention" " (i.e. old outputs get removed), this should be set " "shorter to that, too, to prevent the oldest outputs " "flapping. Increase freely if you intend to process old " "dates - worker's memory is the limit" ), ) months_forward = luigi.IntParameter( default=0, description="extent to which contiguousness is to be assured into future, in months from current time. 
Prevents infinite loop when stop is none", ) def datetime_to_parameter(self, dt): return date(dt.year, dt.month, 1) def parameter_to_datetime(self, p): return datetime(p.year, p.month, 1) def datetime_to_parameters(self, dt): """ Given a date-time, will produce a dictionary of of-params combined with the ranged task parameter """ return self._task_parameters(dt.date()) def parameters_to_datetime(self, p): """ Given a dictionary of parameters, will extract the ranged task parameter value """ dt = p[self._param_name] return datetime(dt.year, dt.month, 1) def _format_datetime(self, dt): return dt.strftime("%Y-%m") def moving_start(self, now): return self._align(now) - relativedelta(months=self.months_back) def moving_stop(self, now): return self._align(now) + relativedelta(months=self.months_forward) def _align(self, dt): return datetime(dt.year, dt.month, 1) def finite_datetimes(self, finite_start, finite_stop): """ Simply returns the points in time that correspond to turn of month. """ start_date = self._align(finite_start) aligned_stop = self._align(finite_stop) dates = [] for m in itertools.count(): t = start_date + relativedelta(months=m) if t >= aligned_stop: return dates if t >= finite_start: dates.append(t) class RangeDaily(RangeDailyBase): """Efficiently produces a contiguous completed range of a daily recurring task that takes a single ``DateParameter``. Falls back to infer it from output filesystem listing to facilitate the common case usage. Convenient to use even from command line, like: .. 
code-block:: console luigi --module your.module RangeDaily --of YourActualTask --start 2014-01-01 """ def missing_datetimes(self, finite_datetimes): try: cls_with_params = functools.partial(self.of, **self.of_params) complete_parameters = self.of.bulk_complete.__func__(cls_with_params, map(self.datetime_to_parameter, finite_datetimes)) return set(finite_datetimes) - set(map(self.parameter_to_datetime, complete_parameters)) except NotImplementedError: return infer_bulk_complete_from_fs( finite_datetimes, lambda d: self._instantiate_task_cls(self.datetime_to_parameter(d)), lambda d: d.strftime("(%Y).*(%m).*(%d)") ) class RangeHourly(RangeHourlyBase): """Efficiently produces a contiguous completed range of an hourly recurring task that takes a single ``DateHourParameter``. Benefits from ``bulk_complete`` information to efficiently cover gaps. Falls back to infer it from output filesystem listing to facilitate the common case usage. Convenient to use even from command line, like: .. code-block:: console luigi --module your.module RangeHourly --of YourActualTask --start 2014-01-01T00 """ def missing_datetimes(self, finite_datetimes): try: # TODO: Why is there a list() here but not for the RangeDaily?? cls_with_params = functools.partial(self.of, **self.of_params) complete_parameters = self.of.bulk_complete.__func__(cls_with_params, list(map(self.datetime_to_parameter, finite_datetimes))) return set(finite_datetimes) - set(map(self.parameter_to_datetime, complete_parameters)) except NotImplementedError: return infer_bulk_complete_from_fs( finite_datetimes, lambda d: self._instantiate_task_cls(self.datetime_to_parameter(d)), lambda d: d.strftime("(%Y).*(%m).*(%d).*(%H)") ) class RangeByMinutes(RangeByMinutesBase): """Efficiently produces a contiguous completed range of an recurring task every interval minutes that takes a single ``DateMinuteParameter``. Benefits from ``bulk_complete`` information to efficiently cover gaps. 
Falls back to infer it from output filesystem listing to facilitate the common case usage. Convenient to use even from command line, like: .. code-block:: console luigi --module your.module RangeByMinutes --of YourActualTask --start 2014-01-01T0123 """ def missing_datetimes(self, finite_datetimes): try: cls_with_params = functools.partial(self.of, **self.of_params) complete_parameters = self.of.bulk_complete.__func__(cls_with_params, map(self.datetime_to_parameter, finite_datetimes)) return set(finite_datetimes) - set(map(self.parameter_to_datetime, complete_parameters)) except NotImplementedError: return infer_bulk_complete_from_fs( finite_datetimes, lambda d: self._instantiate_task_cls(self.datetime_to_parameter(d)), lambda d: d.strftime("(%Y).*(%m).*(%d).*(%H).*(%M)") ) ================================================ FILE: luigi/util.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ ============================================================ Using ``inherits`` and ``requires`` to ease parameter pain ============================================================ Most luigi plumbers will find themselves in an awkward task parameter situation at some point or another. Consider the following "parameter explosion" problem: .. 
code-block:: python class TaskA(luigi.ExternalTask): param_a = luigi.Parameter() def output(self): return luigi.LocalTarget('/tmp/log-{t.param_a}'.format(t=self)) class TaskB(luigi.Task): param_b = luigi.Parameter() param_a = luigi.Parameter() def requires(self): return TaskA(param_a=self.param_a) class TaskC(luigi.Task): param_c = luigi.Parameter() param_b = luigi.Parameter() param_a = luigi.Parameter() def requires(self): return TaskB(param_b=self.param_b, param_a=self.param_a) In workflows requiring many tasks to be chained together in this manner, parameter handling can spiral out of control. Each downstream task becomes more burdensome than the last. Refactoring becomes more difficult. There are several ways one might try to avoid the problem. **Approach 1**: Parameters via command line or config instead of :func:`~luigi.task.Task.requires`. .. code-block:: python class TaskA(luigi.ExternalTask): param_a = luigi.Parameter() def output(self): return luigi.LocalTarget('/tmp/log-{t.param_a}'.format(t=self)) class TaskB(luigi.Task): param_b = luigi.Parameter() def requires(self): return TaskA() class TaskC(luigi.Task): param_c = luigi.Parameter() def requires(self): return TaskB() Then run in the shell like so: .. code-block:: bash luigi --module my_tasks TaskC --param-c foo --TaskB-param-b bar --TaskA-param-a baz Repetitive parameters have been eliminated, but at the cost of making the job's command line interface slightly clunkier. Often this is a reasonable trade-off. But parameters can't always be refactored out of every class. Downstream tasks might also need to use some of those parameters. For example, if ``TaskC`` needs to use ``param_a`` too, then ``param_a`` would still need to be repeated. **Approach 2**: Use a common parameter class .. 
code-block:: python class Params(luigi.Config): param_c = luigi.Parameter() param_b = luigi.Parameter() param_a = luigi.Parameter() class TaskA(Params, luigi.ExternalTask): def output(self): return luigi.LocalTarget('/tmp/log-{t.param_a}'.format(t=self)) class TaskB(Params): def requires(self): return TaskA() class TaskC(Params): def requires(self): return TaskB() This looks great at first glance, but a couple of issues lurk. Now ``TaskA`` and ``TaskB`` have unnecessary significant parameters. Significant parameters help define the identity of a task. Identical tasks are prevented from running at the same time by the central planner. This helps preserve the idempotent and atomic nature of luigi tasks. Unnecessary significant task parameters confuse a task's identity. Under the right circumstances, task identity confusion could lead to that task running when it shouldn't, or failing to run when it should. This approach should only be used when all of the parameters of the config class are significant (or all insignificant) for all of its subclasses. And wait a second... there's a bug in the above code. See it? ``TaskA`` won't behave as an ``ExternalTask`` because the parent classes are specified in the wrong order. This contrived example is easy to fix (by swapping the ordering of the parents of ``TaskA``), but real-world cases can be more difficult to both spot and fix. Inheriting from multiple classes derived from :class:`~luigi.task.Task` should be undertaken with caution and avoided where possible. **Approach 3**: Use :class:`~luigi.util.inherits` and :class:`~luigi.util.requires` The :class:`~luigi.util.inherits` class decorator in this module copies parameters (and nothing else) from one task class to another, and avoids direct pythonic inheritance. .. 
code-block:: python import luigi from luigi.util import inherits class TaskA(luigi.ExternalTask): param_a = luigi.Parameter() def output(self): return luigi.LocalTarget('/tmp/log-{t.param_a}'.format(t=self)) @inherits(TaskA) class TaskB(luigi.Task): param_b = luigi.Parameter() def requires(self): t = self.clone(TaskA) # or t = self.clone_parent() # Wait... what does this clone thingy do? # # Pass it a task class. It instantiates that task. And when it does, it # supplies all parameters (and only those parameters) common to # the caller and callee! # # The call to clone is equivalent to the following (note the # fact that clone avoids passing param_b). # # return TaskA(param_a=self.param_a) return t @inherits(TaskB) class TaskC(luigi.Task): param_c = luigi.Parameter() def requires(self): return self.clone(TaskB) This totally eliminates the need to repeat parameters, avoids inheritance issues, and keeps the task command line interface as simple (as it can be, anyway). Refactoring task parameters is also much easier. The :class:`~luigi.util.requires` helper function can reduce this pattern even further. It does everything :class:`~luigi.util.inherits` does, and also attaches a :func:`~luigi.task.Task.requires` method to your task (still all without pythonic inheritance). But how does it know how to invoke the upstream task? It uses :func:`~luigi.task.Task.clone` behind the scenes! .. code-block:: python import luigi from luigi.util import inherits, requires class TaskA(luigi.ExternalTask): param_a = luigi.Parameter() def output(self): return luigi.LocalTarget('/tmp/log-{t.param_a}'.format(t=self)) @requires(TaskA) class TaskB(luigi.Task): param_b = luigi.Parameter() # The class decorator does this for me! # def requires(self): # return self.clone(TaskA) Use these helper functions effectively to avoid unnecessary repetition and dodge a few potentially nasty workflow pitfalls at the same time. Brilliant! 
""" import datetime import logging from luigi import parameter, task logger = logging.getLogger("luigi-interface") def common_params(task_instance, task_cls): """ Grab all the values in task_instance that are found in task_cls. """ if not isinstance(task_cls, task.Register): raise TypeError("task_cls must be an uninstantiated Task") task_instance_param_names = dict(task_instance.get_params()).keys() task_cls_params_dict = dict(task_cls.get_params()) task_cls_param_names = task_cls_params_dict.keys() common_param_names = set(task_instance_param_names).intersection(set(task_cls_param_names)) common_param_vals = [(key, task_cls_params_dict[key]) for key in common_param_names] common_kwargs = dict((key, task_instance.param_kwargs[key]) for key in common_param_names) vals = dict(task_instance.get_param_values(common_param_vals, [], common_kwargs)) return vals class inherits: """ Task inheritance. *New after Luigi 2.7.6:* multiple arguments support. Usage: .. code-block:: python class AnotherTask(luigi.Task): m = luigi.IntParameter() class YetAnotherTask(luigi.Task): n = luigi.IntParameter() @inherits(AnotherTask) class MyFirstTask(luigi.Task): def requires(self): return self.clone_parent() def run(self): print self.m # this will be defined # ... @inherits(AnotherTask, YetAnotherTask) class MySecondTask(luigi.Task): def requires(self): return self.clone_parents() def run(self): print self.n # this will be defined # ... 
""" def __init__(self, *tasks_to_inherit, **kw_tasks_to_inherit): super(inherits, self).__init__() if not tasks_to_inherit and not kw_tasks_to_inherit: raise TypeError("tasks_to_inherit or kw_tasks_to_inherit must contain at least one task") if tasks_to_inherit and kw_tasks_to_inherit: raise TypeError("Only one of tasks_to_inherit or kw_tasks_to_inherit may be present") self.tasks_to_inherit = tasks_to_inherit self.kw_tasks_to_inherit = kw_tasks_to_inherit def __call__(self, task_that_inherits): # Get all parameter objects from each of the underlying tasks task_iterator = self.tasks_to_inherit or self.kw_tasks_to_inherit.values() for task_to_inherit in task_iterator: for param_name, param_obj in task_to_inherit.get_params(): # Check if the parameter exists in the inheriting task if not hasattr(task_that_inherits, param_name): # If not, add it to the inheriting task setattr(task_that_inherits, param_name, param_obj) # Modify task_that_inherits by adding methods # Handle unnamed tasks as a list, named as a dictionary if self.tasks_to_inherit: def clone_parent(_self, **kwargs): return _self.clone(cls=self.tasks_to_inherit[0], **kwargs) task_that_inherits.clone_parent = clone_parent def clone_parents(_self, **kwargs): return [_self.clone(cls=task_to_inherit, **kwargs) for task_to_inherit in self.tasks_to_inherit] task_that_inherits.clone_parents = clone_parents elif self.kw_tasks_to_inherit: # Even if there is just one named task, return a dictionary def clone_parents(_self, **kwargs): return {task_name: _self.clone(cls=task_to_inherit, **kwargs) for task_name, task_to_inherit in self.kw_tasks_to_inherit.items()} task_that_inherits.clone_parents = clone_parents return task_that_inherits class requires: """ Same as :class:`~luigi.util.inherits`, but also auto-defines the requires method. *New after Luigi 2.7.6:* multiple arguments support. 
""" def __init__(self, *tasks_to_require, **kw_tasks_to_require): super(requires, self).__init__() self.tasks_to_require = tasks_to_require self.kw_tasks_to_require = kw_tasks_to_require def __call__(self, task_that_requires): task_that_requires = inherits(*self.tasks_to_require, **self.kw_tasks_to_require)(task_that_requires) # Modify task_that_requires by adding requires method. # If only one task is required, this single task is returned. # Otherwise, list of tasks is returned def requires(_self): return _self.clone_parent() if len(self.tasks_to_require) == 1 else _self.clone_parents() task_that_requires.requires = requires return task_that_requires class copies: """ Auto-copies a task. Usage: .. code-block:: python @copies(MyTask): class CopyOfMyTask(luigi.Task): def output(self): return LocalTarget(self.date.strftime('/var/xyz/report-%Y-%m-%d')) """ def __init__(self, task_to_copy): super(copies, self).__init__() self.requires_decorator = requires(task_to_copy) def __call__(self, task_that_copies): task_that_copies = self.requires_decorator(task_that_copies) # Modify task_that_copies by subclassing it and adding methods @task._task_wraps(task_that_copies) class Wrapped(task_that_copies): def run(_self): i, o = _self.input(), _self.output() f = o.open("w") # TODO: assert that i, o are Target objects and not complex datastructures for line in i.open("r"): f.write(line) f.close() return Wrapped def delegates(task_that_delegates): """Lets a task call methods on subtask(s). The way this works is that the subtask is run as a part of the task, but the task itself doesn't have to care about the requirements of the subtasks. The subtask doesn't exist from the scheduler's point of view, and its dependencies are instead required by the main task. Example: .. 
code-block:: python class PowersOfN(luigi.Task): n = luigi.IntParameter() def f(self, x): return x ** self.n @delegates class T(luigi.Task): def subtasks(self): return PowersOfN(5) def run(self): print self.subtasks().f(42) """ if not hasattr(task_that_delegates, "subtasks"): # This method can (optionally) define a couple of delegate tasks that # will be accessible as interfaces, meaning that the task can access # those tasks and run methods defined on them, etc raise AttributeError('%s needs to implement the method "subtasks"' % task_that_delegates) @task._task_wraps(task_that_delegates) class Wrapped(task_that_delegates): def deps(self): # Overrides method in base class return task.flatten(self.requires()) + task.flatten([t.deps() for t in task.flatten(self.subtasks())]) def run(self): for t in task.flatten(self.subtasks()): t.run() task_that_delegates.run(self) return Wrapped def previous(task): """ Return a previous Task of the same family. By default checks if this task family only has one non-global parameter and if it is a DateParameter, DateHourParameter or DateIntervalParameter in which case it returns with the time decremented by 1 (hour, day or interval) """ params = task.get_params() previous_params = {} previous_date_params = {} for param_name, param_obj in params: param_value = getattr(task, param_name) if isinstance(param_obj, parameter.DateParameter): previous_date_params[param_name] = param_value - datetime.timedelta(days=1) elif isinstance(param_obj, parameter.DateSecondParameter): previous_date_params[param_name] = param_value - datetime.timedelta(seconds=1) elif isinstance(param_obj, parameter.DateMinuteParameter): previous_date_params[param_name] = param_value - datetime.timedelta(minutes=1) elif isinstance(param_obj, parameter.DateHourParameter): previous_date_params[param_name] = param_value - datetime.timedelta(hours=1) elif isinstance(param_obj, parameter.DateIntervalParameter): previous_date_params[param_name] = param_value.prev() else: 
previous_params[param_name] = param_value previous_params.update(previous_date_params) if len(previous_date_params) == 0: raise NotImplementedError("No task parameter - can't determine previous task") elif len(previous_date_params) > 1: raise NotImplementedError("Too many date-related task parameters - can't determine previous task") else: return task.clone(**previous_params) def get_previous_completed(task, max_steps=10): prev = task for _ in range(max_steps): prev = previous(prev) logger.debug("Checking if %s is complete", prev) if prev.complete(): return prev return None ================================================ FILE: luigi/worker.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ The worker communicates with the scheduler and does two things: 1. Sends all tasks that has to be run 2. Gets tasks from the scheduler that should be run When running in local mode, the worker talks directly to a :py:class:`~luigi.scheduler.Scheduler` instance. When you run a central server, the worker will talk to the scheduler using a :py:class:`~luigi.rpc.RemoteScheduler` instance. Everything in this module is private to luigi and may change in incompatible ways between versions. The exception is the exception types and the :py:class:`worker` config class. 
""" import collections import collections.abc import contextlib import datetime import functools import getpass import importlib import json import logging import multiprocessing import os import queue as Queue import random import signal import socket import subprocess import sys import threading import time import traceback from luigi import notifications from luigi.event import Event from luigi.parameter import BoolParameter, FloatParameter, IntParameter, OptionalParameter, Parameter, TimeDeltaParameter from luigi.scheduler import DISABLED, DONE, FAILED, PENDING, UNKNOWN, WORKER_STATE_ACTIVE, WORKER_STATE_DISABLED, RetryPolicy, Scheduler from luigi.target import Target from luigi.task import Config, DynamicRequirements, Task, flatten from luigi.task_register import TaskClassException, load_task from luigi.task_status import RUNNING logger = logging.getLogger("luigi-interface") # Prevent fork() from being called during a C-level getaddrinfo() which uses a process-global mutex, # that may not be unlocked in child process, resulting in the process being locked indefinitely. fork_lock = threading.Lock() # Why we assert on _WAIT_INTERVAL_EPS: # multiprocessing.Queue.get() is undefined for timeout=0 it seems: # https://docs.python.org/3.4/library/multiprocessing.html#multiprocessing.Queue.get. # I also tried with really low epsilon, but then ran into the same issue where # the test case "test_external_dependency_worker_is_patient" got stuck. So I # unscientifically just set the final value to a floating point number that # "worked for me". 
_WAIT_INTERVAL_EPS = 0.00001 def _is_external(task): return task.run is None or task.run == NotImplemented def _get_retry_policy_dict(task): return RetryPolicy(task.retry_count, task.disable_hard_timeout, task.disable_window)._asdict() class TaskException(Exception): pass GetWorkResponse = collections.namedtuple( "GetWorkResponse", ( "task_id", "running_tasks", "n_pending_tasks", "n_unique_pending", "n_pending_last_scheduled", "worker_state", ), ) class TaskProcess(multiprocessing.Process): """Wrap all task execution in this class. Mainly for convenience since this is run in a separate process.""" # mapping of status_reporter attributes to task attributes that are added to tasks # before they actually run, and removed afterwards forward_reporter_attributes = { "update_tracking_url": "set_tracking_url", "update_status_message": "set_status_message", "update_progress_percentage": "set_progress_percentage", "decrease_running_resources": "decrease_running_resources", "scheduler_messages": "scheduler_messages", } def __init__( self, task, worker_id, result_queue, status_reporter, use_multiprocessing=False, worker_timeout=0, check_unfulfilled_deps=True, check_complete_on_run=False, task_completion_cache=None, ): super(TaskProcess, self).__init__() self.task = task self.worker_id = worker_id self.result_queue = result_queue self.status_reporter = status_reporter self.worker_timeout = task.worker_timeout if task.worker_timeout is not None else worker_timeout self.timeout_time = time.time() + self.worker_timeout if self.worker_timeout else None self.use_multiprocessing = use_multiprocessing or self.timeout_time is not None self.check_unfulfilled_deps = check_unfulfilled_deps self.check_complete_on_run = check_complete_on_run self.task_completion_cache = task_completion_cache # completeness check using the cache self.check_complete = functools.partial(check_complete_cached, completion_cache=task_completion_cache) def _run_get_new_deps(self): task_gen = self.task.run() if not 
isinstance(task_gen, collections.abc.Generator): return None next_send = None while True: try: if next_send is None: requires = next(task_gen) else: requires = task_gen.send(next_send) except StopIteration: return None # if requires is not a DynamicRequirements, create one to use its default behavior if not isinstance(requires, DynamicRequirements): requires = DynamicRequirements(requires) if not requires.complete(self.check_complete): # not all requirements are complete, return them which adds them to the tree new_deps = [(t.task_module, t.task_family, t.to_str_params()) for t in requires.flat_requirements] return new_deps # get the next generator result next_send = requires.paths def run(self): logger.info("[pid %s] Worker %s running %s", os.getpid(), self.worker_id, self.task) if self.use_multiprocessing: # Need to have different random seeds if running in separate processes processID = os.getpid() currentTime = time.time() random.seed(processID * currentTime) status = FAILED expl = "" missing = [] new_deps = [] try: # Verify that all the tasks are fulfilled! For external tasks we # don't care about unfulfilled dependencies, because we are just # checking completeness of self.task so outputs of dependencies are # irrelevant. 
if self.check_unfulfilled_deps and not _is_external(self.task): missing = [] for dep in self.task.deps(): if not self.check_complete(dep): nonexistent_outputs = [output for output in flatten(dep.output()) if not output.exists()] if nonexistent_outputs: missing.append(f"{dep.task_id} ({', '.join(map(str, nonexistent_outputs))})") else: missing.append(dep.task_id) if missing: deps = "dependency" if len(missing) == 1 else "dependencies" raise RuntimeError("Unfulfilled %s at run time: %s" % (deps, ", ".join(missing))) self.task.trigger_event(Event.START, self.task) t0 = time.time() status = None if _is_external(self.task): # External task if self.check_complete(self.task): status = DONE else: status = FAILED expl = "Task is an external data dependency and data does not exist (yet?)." else: with self._forward_attributes(): new_deps = self._run_get_new_deps() if not new_deps: if not self.check_complete_on_run: # update the cache if self.task_completion_cache is not None: self.task_completion_cache[self.task.task_id] = True status = DONE elif self.check_complete(self.task): status = DONE else: raise TaskException("Task finished running, but complete() is still returning false.") else: status = PENDING if new_deps: logger.info("[pid %s] Worker %s new requirements %s", os.getpid(), self.worker_id, self.task) elif status == DONE: self.task.trigger_event(Event.PROCESSING_TIME, self.task, time.time() - t0) expl = self.task.on_success() logger.info("[pid %s] Worker %s done %s", os.getpid(), self.worker_id, self.task) self.task.trigger_event(Event.SUCCESS, self.task) except KeyboardInterrupt: raise except BaseException as ex: status = FAILED expl = self._handle_run_exception(ex) finally: self.result_queue.put((self.task.task_id, status, expl, missing, new_deps)) def _handle_run_exception(self, ex): logger.exception("[pid %s] Worker %s failed %s", os.getpid(), self.worker_id, self.task) self.task.trigger_event(Event.FAILURE, self.task, ex) return self.task.on_failure(ex) def 
_recursive_terminate(self): import psutil try: parent = psutil.Process(self.pid) children = parent.children(recursive=True) # terminate parent. Give it a chance to clean up super(TaskProcess, self).terminate() parent.wait() # terminate children for child in children: try: child.terminate() except psutil.NoSuchProcess: continue except psutil.NoSuchProcess: return def terminate(self): """Terminate this process and its subprocesses.""" # default terminate() doesn't cleanup child processes, it orphans them. try: return self._recursive_terminate() except ImportError: return super(TaskProcess, self).terminate() @contextlib.contextmanager def _forward_attributes(self): # forward configured attributes to the task for reporter_attr, task_attr in self.forward_reporter_attributes.items(): setattr(self.task, task_attr, getattr(self.status_reporter, reporter_attr)) try: yield self finally: # reset attributes again for reporter_attr, task_attr in self.forward_reporter_attributes.items(): setattr(self.task, task_attr, None) # This code and the task_process_context config key currently feels a bit ad-hoc. # Discussion on generalizing it into a plugin system: https://github.com/spotify/luigi/issues/1897 class ContextManagedTaskProcess(TaskProcess): def __init__(self, context, *args, **kwargs): super(ContextManagedTaskProcess, self).__init__(*args, **kwargs) self.context = context def run(self): if self.context: logger.debug("Importing module and instantiating " + self.context) module_path, class_name = self.context.rsplit(".", 1) module = importlib.import_module(module_path) cls = getattr(module, class_name) with cls(self): super(ContextManagedTaskProcess, self).run() else: super(ContextManagedTaskProcess, self).run() class TaskStatusReporter: """ Reports task status information to the scheduler. This object must be pickle-able for passing to `TaskProcess` on systems where fork method needs to pickle the process object (e.g. Windows). 
""" def __init__(self, scheduler, task_id, worker_id, scheduler_messages): self._task_id = task_id self._worker_id = worker_id self._scheduler = scheduler self.scheduler_messages = scheduler_messages def update_tracking_url(self, tracking_url): self._scheduler.add_task(task_id=self._task_id, worker=self._worker_id, status=RUNNING, tracking_url=tracking_url) def update_status_message(self, message): self._scheduler.set_task_status_message(self._task_id, message) def update_progress_percentage(self, percentage): self._scheduler.set_task_progress_percentage(self._task_id, percentage) def decrease_running_resources(self, decrease_resources): self._scheduler.decrease_running_task_resources(self._task_id, decrease_resources) def report_task_statistics(self, statistics): self._scheduler.report_task_statistics(self._task_id, statistics) class SchedulerMessage: """ Message object that is build by the the :py:class:`Worker` when a message from the scheduler is received and passed to the message queue of a :py:class:`Task`. """ def __init__(self, scheduler, task_id, message_id, content, **payload): super(SchedulerMessage, self).__init__() self._scheduler = scheduler self._task_id = task_id self._message_id = message_id self.content = content self.payload = payload def __str__(self): return str(self.content) def __eq__(self, other): return self.content == other def respond(self, response): self._scheduler.add_scheduler_message_response(self._task_id, self._message_id, response) class SingleProcessPool: """ Dummy process pool for using a single processor. Imitates the api of multiprocessing.Pool using single-processor equivalents. """ def apply_async(self, function, args): return function(*args) def close(self): pass def join(self): pass class DequeQueue(collections.deque): """ deque wrapper implementing the Queue interface. 
""" def put(self, obj, block=None, timeout=None): return self.append(obj) def get(self, block=None, timeout=None): try: return self.pop() except IndexError: raise Queue.Empty class AsyncCompletionException(Exception): """ Exception indicating that something went wrong with checking complete. """ def __init__(self, trace): self.trace = trace class TracebackWrapper: """ Class to wrap tracebacks so we can know they're not just strings. """ def __init__(self, trace): self.trace = trace def check_complete_cached(task, completion_cache=None): # check if cached and complete cache_key = task.task_id if completion_cache is not None and completion_cache.get(cache_key): return True # (re-)check the status is_complete = task.complete() # tell the cache when complete if completion_cache is not None and is_complete: completion_cache[cache_key] = is_complete return is_complete def check_complete(task, out_queue, completion_cache=None): """ Checks if task is complete, puts the result to out_queue, optionally using the completion cache. 
""" logger.debug("Checking if %s is complete", task) try: is_complete = check_complete_cached(task, completion_cache) except Exception: is_complete = TracebackWrapper(traceback.format_exc()) out_queue.put((task, is_complete)) class worker(Config): # NOTE: `section.config-variable` in the config_path argument is deprecated in favor of `worker.config_variable` id = Parameter(default="", description="Override the auto-generated worker_id") ping_interval = FloatParameter(default=1.0, config_path=dict(section="core", name="worker-ping-interval")) keep_alive = BoolParameter(default=False, config_path=dict(section="core", name="worker-keep-alive")) count_uniques = BoolParameter( default=False, config_path=dict(section="core", name="worker-count-uniques"), description="worker-count-uniques means that we will keep a worker alive only if it has a unique pending task, as well as having keep-alive true", ) count_last_scheduled = BoolParameter(default=False, description="Keep a worker alive only if there are pending tasks which it was the last to schedule.") wait_interval = FloatParameter(default=1.0, config_path=dict(section="core", name="worker-wait-interval")) wait_jitter = FloatParameter(default=5.0) max_keep_alive_idle_duration = TimeDeltaParameter(default=datetime.timedelta(0)) max_reschedules = IntParameter(default=1, config_path=dict(section="core", name="worker-max-reschedules")) timeout = IntParameter(default=0, config_path=dict(section="core", name="worker-timeout")) task_limit = IntParameter(default=None, config_path=dict(section="core", name="worker-task-limit")) retry_external_tasks = BoolParameter( default=False, config_path=dict(section="core", name="retry-external-tasks"), description="If true, incomplete external tasks will be retested for completion while Luigi is running.", ) send_failure_email = BoolParameter(default=True, description="If true, send e-mails directly from the workeron failure") no_install_shutdown_handler = BoolParameter(default=False, 
description="If true, the SIGUSR1 shutdown handler willNOT be install on the worker") check_unfulfilled_deps = BoolParameter(default=True, description="If true, check for completeness of dependencies before running a task") check_complete_on_run = BoolParameter( default=False, description="If true, only mark tasks as done after running if they are complete. " "Regardless of this setting, the worker will always check if external " "tasks are complete before marking them as done.", ) force_multiprocessing = BoolParameter(default=False, description="If true, use multiprocessing also when running with 1 worker") task_process_context = OptionalParameter( default=None, description="If set to a fully qualified class name, the class will " "be instantiated with a TaskProcess as its constructor parameter and " "applied as a context manager around its run() call, so this can be " "used for obtaining high level customizable monitoring or logging of " "each individual Task run.", ) cache_task_completion = BoolParameter( default=False, description="If true, cache the response of successful completion checks " "of tasks assigned to a worker. This can especially speed up tasks with " "dynamic dependencies but assumes that the completion status does not change " "after it was true the first time.", ) class KeepAliveThread(threading.Thread): """ Periodically tell the scheduler that the worker still lives. """ def __init__(self, scheduler, worker_id, ping_interval, rpc_message_callback): super(KeepAliveThread, self).__init__() self._should_stop = threading.Event() self._scheduler = scheduler self._worker_id = worker_id self._ping_interval = ping_interval self._rpc_message_callback = rpc_message_callback def stop(self): self._should_stop.set() def run(self): while True: self._should_stop.wait(self._ping_interval) if self._should_stop.is_set(): logger.info("Worker %s was stopped. 
Shutting down Keep-Alive thread" % self._worker_id) break with fork_lock: response = None try: response = self._scheduler.ping(worker=self._worker_id) except BaseException: # httplib.BadStatusLine: logger.warning("Failed pinging scheduler") # handle rpc messages if response: for message in response["rpc_messages"]: self._rpc_message_callback(message) def rpc_message_callback(fn): fn.is_rpc_message_callback = True return fn class Worker: """ Worker object communicates with a scheduler. Simple class that talks to a scheduler and: * tells the scheduler what it has to do + its dependencies * asks for stuff to do (pulls it in a loop and runs it) """ def __init__(self, scheduler=None, worker_id=None, worker_processes=1, assistant=False, **kwargs): if scheduler is None: scheduler = Scheduler() self.worker_processes = int(worker_processes) self._worker_info = self._generate_worker_info() self._config = worker(**kwargs) worker_id = worker_id or self._config.id or self._generate_worker_id(self._worker_info) assert self._config.wait_interval >= _WAIT_INTERVAL_EPS, "[worker] wait_interval must be positive" assert self._config.wait_jitter >= 0.0, "[worker] wait_jitter must be equal or greater than zero" self._id = worker_id self._scheduler = scheduler self._assistant = assistant self._stop_requesting_work = False self.host = socket.gethostname() self._scheduled_tasks = {} self._suspended_tasks = {} self._batch_running_tasks = {} self._batch_families_sent = set() self._first_task = None self.add_succeeded = True self.run_succeeded = True self.unfulfilled_counts = collections.defaultdict(int) # note that ``signal.signal(signal.SIGUSR1, fn)`` only works inside the main execution thread, which is why we # provide the ability to conditionally install the hook. 
if not self._config.no_install_shutdown_handler: try: signal.signal(signal.SIGUSR1, self.handle_interrupt) signal.siginterrupt(signal.SIGUSR1, False) except AttributeError: pass # Keep info about what tasks are running (could be in other processes) self._task_result_queue = multiprocessing.Queue() self._running_tasks = {} self._idle_since = None # mp-safe dictionary for caching completation checks across task processes self._task_completion_cache = None if self._config.cache_task_completion: self._task_completion_cache = multiprocessing.Manager().dict() # Stuff for execution_summary self._add_task_history = [] self._get_work_response_history = [] def _add_task(self, *args, **kwargs): """ Call ``self._scheduler.add_task``, but store the values too so we can implement :py:func:`luigi.execution_summary.summary`. """ task_id = kwargs["task_id"] status = kwargs["status"] runnable = kwargs["runnable"] task = self._scheduled_tasks.get(task_id) if task: self._add_task_history.append((task, status, runnable)) kwargs["owners"] = task._owner_list() if task_id in self._batch_running_tasks: for batch_task in self._batch_running_tasks.pop(task_id): self._add_task_history.append((batch_task, status, True)) if task and kwargs.get("params"): kwargs["param_visibilities"] = task._get_param_visibilities() self._scheduler.add_task(*args, **kwargs) logger.info("Informed scheduler that task %s has status %s", task_id, status) def __enter__(self): """ Start the KeepAliveThread. """ self._keep_alive_thread = KeepAliveThread(self._scheduler, self._id, self._config.ping_interval, self._handle_rpc_message) self._keep_alive_thread.daemon = True self._keep_alive_thread.start() return self def __exit__(self, type, value, traceback): """ Stop the KeepAliveThread and kill still running tasks. 
""" self._keep_alive_thread.stop() self._keep_alive_thread.join() for task in self._running_tasks.values(): if task.is_alive(): task.terminate() self._task_result_queue.close() return False # Don't suppress exception def _generate_worker_info(self): # Generate as much info as possible about the worker # Some of these calls might not be available on all OS's args = [("salt", "%09d" % random.randrange(0, 10_000_000_000)), ("workers", self.worker_processes)] try: args += [("host", socket.gethostname())] except BaseException: pass try: args += [("username", getpass.getuser())] except BaseException: pass try: args += [("pid", os.getpid())] except BaseException: pass try: sudo_user = os.getenv("SUDO_USER") if sudo_user: args.append(("sudo_user", sudo_user)) except BaseException: pass return args def _generate_worker_id(self, worker_info): worker_info_str = ", ".join(["{}={}".format(k, v) for k, v in worker_info]) return "Worker({})".format(worker_info_str) def _validate_task(self, task): if not isinstance(task, Task): raise TaskException("Can not schedule non-task %s" % task) if not task.initialized(): # we can't get the repr of it since it's not initialized... raise TaskException("Task of class %s not initialized. Did you override __init__ and forget to call super(...).__init__?" 
% task.__class__.__name__) def _log_complete_error(self, task, tb): log_msg = "Will not run {task} or any dependencies due to error in complete() method:\n{tb}".format(task=task, tb=tb) logger.warning(log_msg) def _log_dependency_error(self, task, tb): log_msg = "Will not run {task} or any dependencies due to error in deps() method:\n{tb}".format(task=task, tb=tb) logger.warning(log_msg) def _log_unexpected_error(self, task): logger.exception("Luigi unexpected framework error while scheduling %s", task) # needs to be called from within except clause def _announce_scheduling_failure(self, task, expl): try: self._scheduler.announce_scheduling_failure( worker=self._id, task_name=str(task), family=task.task_family, params=task.to_str_params(only_significant=True), expl=expl, owners=task._owner_list(), ) except Exception: formatted_traceback = traceback.format_exc() self._email_unexpected_error(task, formatted_traceback) raise def _email_complete_error(self, task, formatted_traceback): self._announce_scheduling_failure(task, formatted_traceback) if self._config.send_failure_email: self._email_error( task, formatted_traceback, subject="Luigi: {task} failed scheduling. Host: {host}", headline="Will not run {task} or any dependencies due to error in complete() method", ) def _email_dependency_error(self, task, formatted_traceback): self._announce_scheduling_failure(task, formatted_traceback) if self._config.send_failure_email: self._email_error( task, formatted_traceback, subject="Luigi: {task} failed scheduling. Host: {host}", headline="Will not run {task} or any dependencies due to error in deps() method", ) def _email_unexpected_error(self, task, formatted_traceback): # this sends even if failure e-mails are disabled, as they may indicate # a more severe failure that may not reach other alerting methods such # as scheduler batch notification self._email_error( task, formatted_traceback, subject="Luigi: Framework error while scheduling {task}. 
Host: {host}", headline="Luigi framework error", ) def _email_task_failure(self, task, formatted_traceback): if self._config.send_failure_email: self._email_error( task, formatted_traceback, subject="Luigi: {task} FAILED. Host: {host}", headline="A task failed when running. Most likely run() raised an exception.", ) def _email_error(self, task, formatted_traceback, subject, headline): formatted_subject = subject.format(task=task, host=self.host) formatted_headline = headline.format(task=task, host=self.host) command = subprocess.list2cmdline(sys.argv) message = notifications.format_task_error(formatted_headline, task, command, formatted_traceback) notifications.send_error_email(formatted_subject, message, task.owner_email) def _handle_task_load_error(self, exception, task_ids): msg = "Cannot find task(s) sent by scheduler: {}".format(",".join(task_ids)) logger.exception(msg) subject = "Luigi: {}".format(msg) error_message = notifications.wrap_traceback(exception) for task_id in task_ids: self._add_task( worker=self._id, task_id=task_id, status=FAILED, runnable=False, expl=error_message, ) notifications.send_error_email(subject, error_message) def add(self, task, multiprocess=False, processes=0): """ Add a Task for the worker to check and possibly schedule and run. Returns True if task and its dependencies were successfully scheduled or completed before. 
""" if self._first_task is None and hasattr(task, "task_id"): self._first_task = task.task_id self.add_succeeded = True if multiprocess: queue = multiprocessing.Manager().Queue() pool = multiprocessing.Pool(processes=processes if processes > 0 else None) else: queue = DequeQueue() pool = SingleProcessPool() self._validate_task(task) pool.apply_async(check_complete, [task, queue, self._task_completion_cache]) # we track queue size ourselves because len(queue) won't work for multiprocessing queue_size = 1 try: seen = {task.task_id} while queue_size: current = queue.get() queue_size -= 1 item, is_complete = current for next in self._add(item, is_complete): if next.task_id not in seen: self._validate_task(next) seen.add(next.task_id) pool.apply_async(check_complete, [next, queue, self._task_completion_cache]) queue_size += 1 except (KeyboardInterrupt, TaskException): raise except Exception as ex: self.add_succeeded = False formatted_traceback = traceback.format_exc() self._log_unexpected_error(task) task.trigger_event(Event.BROKEN_TASK, task, ex) self._email_unexpected_error(task, formatted_traceback) raise finally: pool.close() pool.join() return self.add_succeeded def _add_task_batcher(self, task): family = task.task_family if family not in self._batch_families_sent: task_class = type(task) batch_param_names = task_class.batch_param_names() if batch_param_names: self._scheduler.add_task_batcher( worker=self._id, task_family=family, batched_args=batch_param_names, max_batch_size=task.max_batch_size, ) self._batch_families_sent.add(family) def _add(self, task, is_complete): if self._config.task_limit is not None and len(self._scheduled_tasks) >= self._config.task_limit: logger.warning("Will not run %s or any dependencies due to exceeded task-limit of %d", task, self._config.task_limit) deps = None status = UNKNOWN runnable = False else: formatted_traceback = None try: self._check_complete_value(is_complete) except KeyboardInterrupt: raise except 
AsyncCompletionException as ex: formatted_traceback = ex.trace except BaseException: formatted_traceback = traceback.format_exc() if formatted_traceback is not None: self.add_succeeded = False self._log_complete_error(task, formatted_traceback) task.trigger_event(Event.DEPENDENCY_MISSING, task) self._email_complete_error(task, formatted_traceback) deps = None status = UNKNOWN runnable = False elif is_complete: deps = None status = DONE runnable = False task.trigger_event(Event.DEPENDENCY_PRESENT, task) elif _is_external(task): deps = None status = PENDING runnable = self._config.retry_external_tasks task.trigger_event(Event.DEPENDENCY_MISSING, task) logger.warning("Data for %s does not exist (yet?). The task is an external data dependency, so it cannot be run from this luigi process.", task) else: try: deps = task.deps() self._add_task_batcher(task) except Exception as ex: formatted_traceback = traceback.format_exc() self.add_succeeded = False self._log_dependency_error(task, formatted_traceback) task.trigger_event(Event.BROKEN_TASK, task, ex) self._email_dependency_error(task, formatted_traceback) deps = None status = UNKNOWN runnable = False else: status = PENDING runnable = True if task.disabled: status = DISABLED if deps: for d in deps: self._validate_dependency(d) task.trigger_event(Event.DEPENDENCY_DISCOVERED, task, d) yield d # return additional tasks to add deps = [d.task_id for d in deps] self._scheduled_tasks[task.task_id] = task self._add_task( worker=self._id, task_id=task.task_id, status=status, deps=deps, runnable=runnable, priority=task.priority, resources=task.process_resources(), params=task.to_str_params(), family=task.task_family, module=task.task_module, batchable=task.batchable, retry_policy_dict=_get_retry_policy_dict(task), accepts_messages=task.accepts_messages, ) def _validate_dependency(self, dependency): if isinstance(dependency, Target): raise Exception("requires() can not return Target objects. 
Wrap it in an ExternalTask class") elif not isinstance(dependency, Task): raise Exception("requires() must return Task objects but {} is a {}".format(dependency, type(dependency))) def _check_complete_value(self, is_complete): if is_complete not in (True, False): if isinstance(is_complete, TracebackWrapper): raise AsyncCompletionException(is_complete.trace) raise Exception("Return value of Task.complete() must be boolean (was %r)" % is_complete) def _add_worker(self): self._worker_info.append(("first_task", self._first_task)) self._scheduler.add_worker(self._id, self._worker_info) def _log_remote_tasks(self, get_work_response): logger.debug("Done") logger.debug("There are no more tasks to run at this time") if get_work_response.running_tasks: for r in get_work_response.running_tasks: logger.debug("%s is currently run by worker %s", r["task_id"], r["worker"]) elif get_work_response.n_pending_tasks: logger.debug("There are %s pending tasks possibly being run by other workers", get_work_response.n_pending_tasks) if get_work_response.n_unique_pending: logger.debug("There are %i pending tasks unique to this worker", get_work_response.n_unique_pending) if get_work_response.n_pending_last_scheduled: logger.debug("There are %i pending tasks last scheduled by this worker", get_work_response.n_pending_last_scheduled) def _get_work_task_id(self, get_work_response): if get_work_response.get("task_id") is not None: return get_work_response["task_id"] elif "batch_id" in get_work_response: try: task = load_task( module=get_work_response.get("task_module"), task_name=get_work_response["task_family"], params_str=get_work_response["task_params"], ) except Exception as ex: self._handle_task_load_error(ex, get_work_response["batch_task_ids"]) self.run_succeeded = False return None self._scheduler.add_task( worker=self._id, task_id=task.task_id, module=get_work_response.get("task_module"), family=get_work_response["task_family"], params=task.to_str_params(), status=RUNNING, 
batch_id=get_work_response["batch_id"], ) return task.task_id else: return None def _get_work(self): if self._stop_requesting_work: return GetWorkResponse(None, 0, 0, 0, 0, WORKER_STATE_DISABLED) if self.worker_processes > 0: logger.debug("Asking scheduler for work...") r = self._scheduler.get_work( worker=self._id, host=self.host, assistant=self._assistant, current_tasks=list(self._running_tasks.keys()), ) else: logger.debug("Checking if tasks are still pending") r = self._scheduler.count_pending(worker=self._id) running_tasks = r["running_tasks"] task_id = self._get_work_task_id(r) self._get_work_response_history.append( { "task_id": task_id, "running_tasks": running_tasks, } ) if task_id is not None and task_id not in self._scheduled_tasks: logger.info("Did not schedule %s, will load it dynamically", task_id) try: # TODO: we should obtain the module name from the server! self._scheduled_tasks[task_id] = load_task(module=r.get("task_module"), task_name=r["task_family"], params_str=r["task_params"]) except TaskClassException as ex: self._handle_task_load_error(ex, [task_id]) task_id = None self.run_succeeded = False if task_id is not None and "batch_task_ids" in r: batch_tasks = filter(None, [self._scheduled_tasks.get(batch_id) for batch_id in r["batch_task_ids"]]) self._batch_running_tasks[task_id] = batch_tasks return GetWorkResponse( task_id=task_id, running_tasks=running_tasks, n_pending_tasks=r["n_pending_tasks"], n_unique_pending=r["n_unique_pending"], # TODO: For a tiny amount of time (a month?) 
we'll keep forwards compatibility # That is you can user a newer client than server (Sep 2016) n_pending_last_scheduled=r.get("n_pending_last_scheduled", 0), worker_state=r.get("worker_state", WORKER_STATE_ACTIVE), ) def _run_task(self, task_id): if task_id in self._running_tasks: logger.debug("Got already running task id {} from scheduler, taking a break".format(task_id)) next(self._sleeper()) return task = self._scheduled_tasks[task_id] task_process = self._create_task_process(task) self._running_tasks[task_id] = task_process if task_process.use_multiprocessing: with fork_lock: task_process.start() else: # Run in the same process task_process.run() def _create_task_process(self, task): message_queue = multiprocessing.Queue() if task.accepts_messages else None reporter = TaskStatusReporter(self._scheduler, task.task_id, self._id, message_queue) use_multiprocessing = self._config.force_multiprocessing or bool(self.worker_processes > 1) return ContextManagedTaskProcess( self._config.task_process_context, task, self._id, self._task_result_queue, reporter, use_multiprocessing=use_multiprocessing, worker_timeout=self._config.timeout, check_unfulfilled_deps=self._config.check_unfulfilled_deps, check_complete_on_run=self._config.check_complete_on_run, task_completion_cache=self._task_completion_cache, ) def _purge_children(self): """ Find dead children and put a response on the result queue. 
:return: """ for task_id, p in self._running_tasks.items(): if not p.is_alive() and p.exitcode: error_msg = "Task {} died unexpectedly with exit code {}".format(task_id, p.exitcode) p.task.trigger_event(Event.PROCESS_FAILURE, p.task, error_msg) elif p.timeout_time is not None and time.time() > float(p.timeout_time) and p.is_alive(): p.terminate() error_msg = "Task {} timed out after {} seconds and was terminated.".format(task_id, p.worker_timeout) p.task.trigger_event(Event.TIMEOUT, p.task, error_msg) else: continue logger.info(error_msg) self._task_result_queue.put((task_id, FAILED, error_msg, [], [])) def _handle_next_task(self): """ We have to catch three ways a task can be "done": 1. normal execution: the task runs/fails and puts a result back on the queue, 2. new dependencies: the task yielded new deps that were not complete and will be rescheduled and dependencies added, 3. child process dies: we need to catch this separately. """ self._idle_since = None while True: self._purge_children() # Deal with subprocess failures try: task_id, status, expl, missing, new_requirements = self._task_result_queue.get(timeout=self._config.wait_interval) except Queue.Empty: return task = self._scheduled_tasks[task_id] if not task or task_id not in self._running_tasks: continue # Not a running task. Probably already removed. # Maybe it yielded something? # external task if run not implemented, retry-able if config option is enabled. 
external_task_retryable = _is_external(task) and self._config.retry_external_tasks if status == FAILED and not external_task_retryable: self._email_task_failure(task, expl) new_deps = [] if new_requirements: new_req = [load_task(module, name, params) for module, name, params in new_requirements] for t in new_req: self.add(t) new_deps = [t.task_id for t in new_req] self._add_task( worker=self._id, task_id=task_id, status=status, expl=json.dumps(expl), resources=task.process_resources(), runnable=None, params=task.to_str_params(), family=task.task_family, module=task.task_module, new_deps=new_deps, assistant=self._assistant, retry_policy_dict=_get_retry_policy_dict(task), ) self._running_tasks.pop(task_id) # re-add task to reschedule missing dependencies if missing: reschedule = True # keep out of infinite loops by not rescheduling too many times for task_id in missing: self.unfulfilled_counts[task_id] += 1 if self.unfulfilled_counts[task_id] > self._config.max_reschedules: reschedule = False if reschedule: self.add(task) self.run_succeeded &= (status == DONE) or (len(new_deps) > 0) return def _sleeper(self): # TODO is exponential backoff necessary? while True: jitter = self._config.wait_jitter wait_interval = self._config.wait_interval + random.uniform(0, jitter) logger.debug("Sleeping for %f seconds", wait_interval) time.sleep(wait_interval) yield def _keep_alive(self, get_work_response): """ Returns true if a worker should stay alive given. If worker-keep-alive is not set, this will always return false. For an assistant, it will always return the value of worker-keep-alive. Otherwise, it will return true for nonzero n_pending_tasks. If worker-count-uniques is true, it will also require that one of the tasks is unique to this worker. 
""" if not self._config.keep_alive: return False elif self._assistant: return True elif self._config.count_last_scheduled: return get_work_response.n_pending_last_scheduled > 0 elif self._config.count_uniques: return get_work_response.n_unique_pending > 0 elif get_work_response.n_pending_tasks == 0: return False elif not self._config.max_keep_alive_idle_duration: return True elif not self._idle_since: return True else: time_to_shutdown = self._idle_since + self._config.max_keep_alive_idle_duration - datetime.datetime.now() logger.debug("[%s] %s until shutdown", self._id, time_to_shutdown) return time_to_shutdown > datetime.timedelta(0) def handle_interrupt(self, signum, _): """ Stops the assistant from asking for more work on SIGUSR1 """ if signum == signal.SIGUSR1: self._start_phasing_out() def _start_phasing_out(self): """ Go into a mode where we dont ask for more work and quit once existing tasks are done. """ self._config.keep_alive = False self._stop_requesting_work = True def run(self): """ Returns True if all scheduled tasks were executed successfully. 
""" logger.info("Running Worker with %d processes", self.worker_processes) sleeper = self._sleeper() self.run_succeeded = True self._add_worker() while True: while len(self._running_tasks) >= self.worker_processes > 0: logger.debug("%d running tasks, waiting for next task to finish", len(self._running_tasks)) self._handle_next_task() get_work_response = self._get_work() if get_work_response.worker_state == WORKER_STATE_DISABLED: self._start_phasing_out() if get_work_response.task_id is None: if not self._stop_requesting_work: self._log_remote_tasks(get_work_response) if len(self._running_tasks) == 0: self._idle_since = self._idle_since or datetime.datetime.now() if self._keep_alive(get_work_response): next(sleeper) continue else: break else: self._handle_next_task() continue # task_id is not None: logger.debug("Pending tasks: %s", get_work_response.n_pending_tasks) self._run_task(get_work_response.task_id) while len(self._running_tasks): logger.debug("Shut down Worker, %d more tasks to go", len(self._running_tasks)) self._handle_next_task() return self.run_succeeded def _handle_rpc_message(self, message): logger.info("Worker %s got message %s" % (self._id, message)) # the message is a dict {'name': , 'kwargs': } name = message["name"] kwargs = message["kwargs"] # find the function and check if it's callable and configured to work # as a message callback func = getattr(self, name, None) tpl = (self._id, name) if not callable(func): logger.error("Worker %s has no function '%s'" % tpl) elif not getattr(func, "is_rpc_message_callback", False): logger.error("Worker %s function '%s' is not available as rpc message callback" % tpl) else: logger.info("Worker %s successfully dispatched rpc message to function '%s'" % tpl) func(**kwargs) @rpc_message_callback def set_worker_processes(self, n): # set the new value self.worker_processes = max(1, n) # tell the scheduler self._scheduler.add_worker(self._id, {"workers": self.worker_processes}) @rpc_message_callback def 
dispatch_scheduler_message(self, task_id, message_id, content, **kwargs): task_id = str(task_id) if task_id in self._running_tasks: task_process = self._running_tasks[task_id] if task_process.status_reporter.scheduler_messages: message = SchedulerMessage(self._scheduler, task_id, message_id, content, **kwargs) task_process.status_reporter.scheduler_messages.put(message) ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ['hatchling', 'hatch-fancy-pypi-readme'] build-backend = 'hatchling.build' [project] name = "luigi" description = "Workflow mgmgt + task scheduling + dependency resolution." authors = [ {name = "The Luigi Authors"} ] license = {file = "LICENSE"} requires-python = ">=3.10, <3.14" dependencies = [ "python-dateutil>=2.7.5,<3", "tenacity>=9", "tornado>=5.0,<7", "python-daemon<2.2.0; sys_platform == 'win32'", "python-daemon; sys_platform != 'win32'", "typing-extensions>=4.12.2", ] classifiers = [ "Development Status :: 5 - Production/Stable", "Environment :: Console", "Environment :: Web Environment", "Intended Audience :: Developers", "Intended Audience :: System Administrators", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Topic :: System :: Monitoring", ] dynamic = ["version", "readme"] [project.urls] Homepage = "https://github.com/spotify/luigi" [project.scripts] luigi = "luigi.cmdline:luigi_run" luigid = "luigi.cmdline:luigid" luigi-grep = "luigi.tools.luigi_grep:main" luigi-deps = "luigi.tools.deps:main" luigi-deps-tree = "luigi.tools.deps_tree:main" [project.optional-dependencies] jsonschema = ["jsonschema"] prometheus = ["prometheus-client>=0.5,<0.25"] toml = ["toml<2.0.0"] [dependency-groups] # groups and dependencies should be sort in lexicographical order cdh = [ 
"hdfs>=2.0.4,<3.0.0", ] common = [ "avro-python3", "azure-storage-blob<=12.20.0", "boto>=2.42,<3.0", "boto3>=1.11.0", "codecov>=1.4.0", "coverage>=5.0,<6", "datadog==0.22.0", "docker>=2.1.0", "elasticsearch>=1.0.0,<2.0.0", "google-compute-engine", "HTTPretty==0.8.10", "hypothesis>=6.7.0,<7.0.0", "jsonschema", "mock<2.0", "moto>=1.3.10,<5.0", "mypy", "mysql-connector-python", "prometheus-client>=0.5.0,<0.25", "psutil<4.0", "pygments", "pyhive[presto]==0.6.1", "pymongo==3.4.0", "pytest", "pytest-cov", "pytest-xdist", "requests>=2.20.0,<=2.31.0", "responses<1.0.0", "s3transfer>=0.3,<4.0", "selenium==3.0.2", "sqlalchemy<1.4", "toml<2.0.0", "types-python-dateutil", "types-requests", "types-toml", ] docs = [ "azure-storage-blob<=12.28.0", "jinja2>=3.1,<4", "mypy", "prometheus-client>=0.5.0,<0.25", "Sphinx>=9.0,<10; python_version >= '3.12'", "sphinx-rtd-theme>=2.0; python_version >= '3.12'", "sqlalchemy", ] dropbox = [ "dropbox>=11.0.0", ] gcloud = [ "google-api-python-client>=1.6.6,<2.0", "google-auth==1.4.1", "google-auth-httplib2==0.0.3", ] hdp = [ "hdfs>=2.0.4,<3.0.0", ] lint = [ "ruff", ] postgres = [ "pg8000>=1.23.0", "psycopg2<3.0", ] unixsocket = [ "requests-unixsocket<1.0", ] # for tox test dependencies test_cdh = [ {include-group = "cdh"}, {include-group = "common"}, ] test_dropbox = [ {include-group = "dropbox"}, {include-group = "common"}, ] test_gcloud = [ {include-group = "gcloud"}, {include-group = "common"}, ] test_hdp = [ {include-group = "hdp"}, {include-group = "common"}, ] test_postgres = [ {include-group = "postgres"}, {include-group = "common"}, ] test_unixsocket = [ {include-group = "unixsocket"}, {include-group = "common"}, ] visualizer = [ "mock<2.0", "selenium==3.0.2" ] # for local development dev = [ {include-group = "gcloud"}, {include-group = "postgres"}, {include-group = "dropbox"}, {include-group = "cdh"}, # same deps as hdp {include-group = "unixsocket"}, {include-group = "common"}, {include-group = "lint"}, ] [tool.mypy] # Keep this set 
to the minimum supported Python version (see requires-python in [project]) python_version = "3.10" ignore_missing_imports = true # Gradually tighten: remove a module from the ignore list below after fixing its errors [[tool.mypy.overrides]] module = [ "luigi.contrib.gcs", "luigi.contrib.hadoop", "luigi.contrib.hdfs.config", "luigi.contrib.postgres", "luigi.contrib.redis_store", "luigi.contrib.spark", "luigi.contrib.sqla", "luigi.interface", "luigi.notifications", "luigi.tools.range", "luigi.worker", ] ignore_errors = true [tool.ruff] line-length = 160 exclude = ["doc"] [tool.ruff.lint] select = [ "E", # pycodestyle errors "F", # pyflakes "I", # isort "W", # pycodestyle warnings ] [tool.ruff.lint.isort] known-first-party = ["luigi"] [tool.uv] default-groups = ['dev'] cache-keys = [ { file = "pyproject.toml" }, { git = true } ] [tool.hatch.version] path = "luigi/__version__.py" [tool.hatch.build.targets.sdist] include = [ "/LICENSE", "/README.rst", "/examples", "/luigi", "/test", ] [tool.hatch.metadata.hooks.fancy-pypi-readme] content-type = "text/x-rst" # construct the PyPI readme from README.md and HISTORY.md fragments = [ {text = "\n.. 
note::\n\tFor the latest source, discussion, etc, please visit the\n\t`GitHub repository `_\n"}, {path = "README.rst"}, ] ================================================ FILE: scripts/ci/conditional_tox.sh ================================================ #!/usr/bin/env bash set -ex ENDENV=$(echo $TOXENV | tail -c 7) if [[ $ENDENV == gcloud ]] then [[ $DIDNT_CREATE_GCP_CREDS = 1 ]] || tox else tox --hashseed 1 fi ================================================ FILE: scripts/ci/install_start_azurite.sh ================================================ #!/usr/bin/env bash echo "$DOCKERHUB_TOKEN" | docker login -u spotifyci --password-stdin docker pull mcr.microsoft.com/azure-storage/azurite mkdir -p blob_emulator $1/stop_azurite.sh docker run -p 10000:10000 -v blob_emulator:/data -e AZURITE_ACCOUNTS=devstoreaccount1:YXp1cml0ZQ== -d mcr.microsoft.com/azure-storage/azurite azurite-blob -l /data --blobHost 0.0.0.0 --blobPort 10000 ================================================ FILE: scripts/ci/setup_hadoop_env.sh ================================================ #!/usr/bin/env bash HADOOP_DISTRO=${HADOOP_DISTRO:-"hdp"} ONLY_DOWNLOAD=${ONLY_DOWNLOAD:-false} ONLY_EXTRACT=${ONLY_EXTRACT:-false} while test $# -gt 0; do case "$1" in -h|--help) echo "Setup environment for snakebite tests" echo " " echo "options:" echo -e "\t-h, --help show brief help" echo -e "\t-o, --only-download just download hadoop tar(s)" echo -e "\t-e, --only-extract just extract hadoop tar(s)" echo -e "\t-d, --distro select distro (hdp|cdh)" exit 0 ;; -o|--only-download) shift ONLY_DOWNLOAD=true ;; -e|--only-extract) shift ONLY_EXTRACT=true ;; -d|--distro) shift if test $# -gt 0; then HADOOP_DISTRO=$1 else echo "No Hadoop distro specified - abort" >&2 exit 1 fi shift ;; *) echo "Unknown options: $1" >&2 exit 1 ;; esac done if $ONLY_DOWNLOAD && $ONLY_EXTRACT; then echo "Both only-download and only-extract specified - abort" >&2 exit 1 fi mkdir -p $HADOOP_HOME if [ $HADOOP_DISTRO = "cdh" ]; then 
URL="http://archive.cloudera.com/cdh5/cdh/5/hadoop-latest.tar.gz" elif [ $HADOOP_DISTRO = "hdp" ]; then # This site provides good URLs: # https://github.com/saltstack-formulas/hadoop-formula/blob/5034a2204da691eceb9c2d8cd8260f11d5cc06f3/hadoop/settings.sls URL="http://public-repo-1.hortonworks.com/HDP/centos6/2.x/updates/2.2.6.0/tars/hadoop-2.6.0.2.2.6.0-2800.tar.gz" else echo "No/bad HADOOP_DISTRO='${HADOOP_DISTRO}' specified" >&2 exit 1 fi if ! $ONLY_EXTRACT && [ ! -e ${HADOOP_HOME}/hadoop.tar.gz ] ; then echo "Downloading Hadoop from $URL to ${HADOOP_HOME}/hadoop.tar.gz" curl -z ${HADOOP_HOME}/hadoop.tar.gz -o ${HADOOP_HOME}/hadoop.tar.gz -L $URL if [ $? != 0 ]; then echo "Failed to download Hadoop from $URL - abort" >&2 exit 1 fi fi if $ONLY_DOWNLOAD; then exit 0 fi echo "Extracting ${HADOOP_HOME}/hadoop.tar.gz into $HADOOP_HOME" tar zxf ${HADOOP_HOME}/hadoop.tar.gz --strip-components 1 -C $HADOOP_HOME if [ $? != 0 ]; then echo "Failed to extract Hadoop from ${HADOOP_HOME}/hadoop.tar.gz to ${HADOOP_HOME} - abort" >&2 exit 1 fi ================================================ FILE: scripts/ci/stop_azurite.sh ================================================ #!/usr/bin/env bash docker stop "$(docker ps -q --filter ancestor=mcr.microsoft.com/azure-storage/azurite)" ================================================ FILE: test/_mysqldb_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # import mysql.connector from helpers import unittest from luigi.contrib.mysqldb import MySqlTarget host = "localhost" port = 3306 database = "luigi_test" username = None password = None table_updates = "table_updates" def _create_test_database(): con = mysql.connector.connect(user=username, password=password, host=host, port=port, autocommit=True) con.cursor().execute("CREATE DATABASE IF NOT EXISTS %s" % database) _create_test_database() target = MySqlTarget(host, database, username, password, "", "update_id") class MySqlTargetTest(unittest.TestCase): def test_touch_and_exists(self): drop() self.assertFalse(target.exists(), "Target should not exist before touching it") target.touch() self.assertTrue(target.exists(), "Target should exist after touching it") def drop(): con = target.connect(autocommit=True) con.cursor().execute("DROP TABLE IF EXISTS %s" % table_updates) ================================================ FILE: test/_test_ftp.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # this is an integration test. to run this test requires that an actuall FTP server # is running somewhere. 
to run a local ftp server do the following # pip install pyftpdlib==1.5.0 # mkdir /tmp/luigi-test-ftp/ # sudo python -m _test_ftp import datetime import ftplib import os import shutil import sys from io import StringIO from helpers import unittest from luigi.contrib.ftp import RemoteFileSystem, RemoteTarget # dumb files FILE1 = """this is file1""" FILE2 = """this is file2""" FILE3 = """this is file3""" HOST = "localhost" USER = "luigi" PWD = "some_password" class TestFTPFilesystem(unittest.TestCase): def setUp(self): """Creates structure /test /test/file1 /test/hola/ /test/hola/file2 /test/hola/singlefile /test/hola/file3 """ # create structure ftp = ftplib.FTP(HOST, USER, PWD) ftp.cwd("/") ftp.mkd("test") ftp.cwd("test") ftp.mkd("hola") ftp.cwd("hola") f2 = StringIO(FILE2) ftp.storbinary("STOR file2", f2) # send the file f3 = StringIO(FILE3) ftp.storbinary("STOR file3", f3) # send the file ftp.cwd("..") f1 = StringIO(FILE1) ftp.storbinary("STOR file1", f1) # send the file ftp.close() def test_file_remove(self): """Delete with recursive deactivated""" rfs = RemoteFileSystem(HOST, USER, PWD) rfs.remove("/test/hola/file3", recursive=False) rfs.remove("/test/hola/file2", recursive=False) rfs.remove("/test/hola", recursive=False) rfs.remove("/test/file1", recursive=False) rfs.remove("/test", recursive=False) ftp = ftplib.FTP(HOST, USER, PWD) list_dir = ftp.nlst() self.assertFalse("test" in list_dir) def test_recursive_remove(self): """Test FTP filesystem removing files recursive""" rfs = RemoteFileSystem(HOST, USER, PWD) rfs.remove("/test") ftp = ftplib.FTP(HOST, USER, PWD) list_dir = ftp.nlst() self.assertFalse("test" in list_dir) class TestFTPFilesystemUpload(unittest.TestCase): def test_single(self): """Test upload file with creation of intermediate folders""" ftp_path = "/test/nest/luigi-test" local_filepath = "/tmp/luigi-test-ftp" # create local temp file with open(local_filepath, "w") as outfile: outfile.write("something to fill") rfs = RemoteFileSystem(HOST, 
USER, PWD) rfs.put(local_filepath, ftp_path) # manually connect to ftp ftp = ftplib.FTP(HOST, USER, PWD) ftp.cwd("/test/nest") list_dir = ftp.nlst() # file is successfuly created self.assertTrue("luigi-test" in list_dir) # delete tmp files ftp.delete("luigi-test") ftp.cwd("/") ftp.rmd("/test/nest") ftp.rmd("test") os.remove(local_filepath) ftp.close() class TestRemoteTarget(unittest.TestCase): def test_put(self): """Test RemoteTarget put method with uploading to an FTP""" local_filepath = "/tmp/luigi-remotetarget-write-test" remote_file = "/test/example.put.file" # create local temp file with open(local_filepath, "w") as outfile: outfile.write("something to fill") remotetarget = RemoteTarget(remote_file, HOST, username=USER, password=PWD) remotetarget.put(local_filepath) # manually connect to ftp ftp = ftplib.FTP(HOST, USER, PWD) ftp.cwd("/test") list_dir = ftp.nlst() # file is successfuly created self.assertTrue(remote_file.split("/")[-1] in list_dir) # clean os.remove(local_filepath) ftp.delete(remote_file) ftp.cwd("/") ftp.rmd("test") ftp.close() def test_get(self): """Test Remote target get method downloading a file from ftp""" local_filepath = "/tmp/luigi-remotetarget-read-test" tmp_filepath = "/tmp/tmp-luigi-remotetarget-read-test" remote_file = "/test/example.get.file" # create local temp file with open(tmp_filepath, "w") as outfile: outfile.write("something to fill") # manualy upload to ftp ftp = ftplib.FTP(HOST, USER, PWD) ftp.mkd("test") ftp.storbinary("STOR %s" % remote_file, open(tmp_filepath, "rb")) ftp.close() # execute command remotetarget = RemoteTarget(remote_file, HOST, username=USER, password=PWD) remotetarget.get(local_filepath) # make sure that it can open file with remotetarget.open("r") as fin: self.assertEqual(fin.read(), "something to fill") # check for cleaning temporary files if sys.version_info >= (3, 2): # cleanup uses tempfile.TemporaryDirectory only available in 3.2+ temppath = remotetarget._RemoteTarget__tmp_path 
self.assertTrue(os.path.exists(temppath)) remotetarget = None # garbage collect remotetarget self.assertFalse(os.path.exists(temppath)) # file is successfuly created self.assertTrue(os.path.exists(local_filepath)) # test RemoteTarget with mtime ts = datetime.datetime.now() - datetime.timedelta(days=2) delayed_remotetarget = RemoteTarget(remote_file, HOST, username=USER, password=PWD, mtime=ts) self.assertTrue(delayed_remotetarget.exists()) ts = datetime.datetime.now() + datetime.timedelta(days=2) # who knows what timezone it is in delayed_remotetarget = RemoteTarget(remote_file, HOST, username=USER, password=PWD, mtime=ts) self.assertFalse(delayed_remotetarget.exists()) # clean os.remove(local_filepath) os.remove(tmp_filepath) ftp = ftplib.FTP(HOST, USER, PWD) ftp.delete(remote_file) ftp.cwd("/") ftp.rmd("test") ftp.close() def _run_ftp_server(): from pyftpdlib.authorizers import DummyAuthorizer from pyftpdlib.handlers import FTPHandler from pyftpdlib.servers import FTPServer # Instantiate a dummy authorizer for managing 'virtual' users authorizer = DummyAuthorizer() tmp_folder = "/tmp/luigi-test-ftp-server/" if os.path.exists(tmp_folder): shutil.rmtree(tmp_folder) os.mkdir(tmp_folder) authorizer.add_user(USER, PWD, tmp_folder, perm="elradfmwM") handler = FTPHandler handler.authorizer = authorizer address = ("localhost", 21) server = FTPServer(address, handler) server.serve_forever() if __name__ == "__main__": _run_ftp_server() ================================================ FILE: test/auto_namespace_test/__init__.py ================================================ import luigi luigi.auto_namespace(scope=__name__) ================================================ FILE: test/auto_namespace_test/my_namespace_test.py ================================================ from helpers import LuigiTestCase import luigi class MyNamespaceTest(LuigiTestCase): def test_auto_namespace_scope(self): class MyTask(luigi.Task): pass 
self.assertTrue(self.run_locally(["auto_namespace_test.my_namespace_test.MyTask"])) self.assertEqual(MyTask.get_task_namespace(), "auto_namespace_test.my_namespace_test") ================================================ FILE: test/batch_notifier_test.py ================================================ # coding=utf-8 import unittest from smtplib import SMTPServerDisconnected import mock import luigi.batch_notifier BATCH_NOTIFIER_DEFAULTS = { "error_lines": 0, "error_messages": 0, "group_by_error_messages": False, } class BatchNotifier(luigi.batch_notifier.BatchNotifier): """BatchNotifier class with defaults that produce smaller output for testing""" def __init__(self, **kwargs): full_args = BATCH_NOTIFIER_DEFAULTS.copy() full_args.update(kwargs) super(BatchNotifier, self).__init__(**full_args) class BatchNotifierTest(unittest.TestCase): def setUp(self): self.time_mock = mock.patch("luigi.batch_notifier.time.time") self.time = self.time_mock.start() self.time.return_value = 0.0 self.send_email_mock = mock.patch("luigi.batch_notifier.send_email") self.send_email = self.send_email_mock.start() self.email_mock = mock.patch("luigi.batch_notifier.email") self.email = self.email_mock.start() self.email().sender = "sender@test.com" self.email().receiver = "r@test.com" def tearDown(self): self.time_mock.stop() self.send_email_mock.stop() self.email_mock.stop() def incr_time(self, minutes): self.time.return_value += minutes * 60 def check_email_send(self, subject, message, receiver="r@test.com", sender="sender@test.com"): self.send_email.assert_called_once_with(subject, message, sender, (receiver,)) def test_send_single_failure(self): bn = BatchNotifier(batch_mode="all") bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", []) bn.send_email() self.check_email_send("Luigi: 1 failure in the last 60 minutes", "- Task(a=5) (1 failure)") def test_do_not_send_single_failure_without_receiver(self): self.email().receiver = None bn = BatchNotifier(batch_mode="all") 
bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", []) bn.send_email() self.send_email.assert_not_called() def test_send_single_failure_to_owner_only(self): self.email().receiver = None bn = BatchNotifier(batch_mode="all") bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", ["owner@test.com"]) bn.send_email() self.check_email_send( "Luigi: Your tasks have 1 failure in the last 60 minutes", "- Task(a=5) (1 failure)", receiver="owner@test.com", ) def test_send_single_disable(self): bn = BatchNotifier(batch_mode="all") for _ in range(10): bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", []) bn.add_disable("Task(a=5)", "Task", {"a": 5}, []) bn.send_email() self.check_email_send("Luigi: 10 failures, 1 disable in the last 60 minutes", "- Task(a=5) (10 failures, 1 disable)") def test_send_multiple_disables(self): bn = BatchNotifier(batch_mode="family") for _ in range(10): bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", []) bn.add_failure("Task(a=6)", "Task", {"a": 6}, "error", []) bn.add_disable("Task(a=5)", "Task", {"a": 5}, []) bn.add_disable("Task(a=6)", "Task", {"a": 6}, []) bn.send_email() self.check_email_send("Luigi: 20 failures, 2 disables in the last 60 minutes", "- Task (20 failures, 2 disables)") def test_send_single_scheduling_fail(self): bn = BatchNotifier(batch_mode="family") bn.add_scheduling_fail("Task()", "Task", {}, "error", []) bn.send_email() self.check_email_send( "Luigi: 1 scheduling failure in the last 60 minutes", "- Task (1 scheduling failure)", ) def test_multiple_failures_of_same_job(self): bn = BatchNotifier(batch_mode="all") bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", []) bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", []) bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", []) bn.send_email() self.check_email_send("Luigi: 3 failures in the last 60 minutes", "- Task(a=5) (3 failures)") def test_multiple_failures_of_multiple_jobs(self): bn = BatchNotifier(batch_mode="all") bn.add_failure("Task(a=5)", 
"Task", {"a": 5}, "error", []) bn.add_failure("Task(a=6)", "Task", {"a": 6}, "error", []) bn.add_failure("Task(a=6)", "Task", {"a": 6}, "error", []) bn.send_email() self.check_email_send("Luigi: 3 failures in the last 60 minutes", "- Task(a=6) (2 failures)\n- Task(a=5) (1 failure)") def test_group_on_family(self): bn = BatchNotifier(batch_mode="family") bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", []) bn.add_failure("Task(a=6)", "Task", {"a": 6}, "error", []) bn.add_failure("Task(a=6)", "Task", {"a": 6}, "error", []) bn.add_failure("OtherTask(a=6)", "OtherTask", {"a": 6}, "error", []) bn.send_email() self.check_email_send("Luigi: 4 failures in the last 60 minutes", "- Task (3 failures)\n- OtherTask (1 failure)") def test_group_on_unbatched_params(self): bn = BatchNotifier(batch_mode="unbatched_params") bn.add_failure("Task(a=5, b=1)", "Task", {"a": 5}, "error", []) bn.add_failure("Task(a=5, b=2)", "Task", {"a": 5}, "error", []) bn.add_failure("Task(a=6, b=1)", "Task", {"a": 6}, "error", []) bn.add_failure("Task(a=6, b=2)", "Task", {"a": 6}, "error", []) bn.add_failure("Task(a=6, b=3)", "Task", {"a": 6}, "error", []) bn.add_failure("Task(a=6, b=4)", "Task", {"a": 6}, "error", []) bn.add_failure("OtherTask(a=5, b=1)", "OtherTask", {"a": 5}, "error", []) bn.add_failure("OtherTask(a=6, b=1)", "OtherTask", {"a": 6}, "error", []) bn.add_failure("OtherTask(a=6, b=2)", "OtherTask", {"a": 6}, "error", []) bn.add_failure("OtherTask(a=6, b=3)", "OtherTask", {"a": 6}, "error", []) bn.send_email() self.check_email_send( "Luigi: 10 failures in the last 60 minutes", "- Task(a=6) (4 failures)\n- OtherTask(a=6) (3 failures)\n- Task(a=5) (2 failures)\n- OtherTask(a=5) (1 failure)", ) def test_include_one_expl_includes_latest(self): bn = BatchNotifier(batch_mode="family", error_messages=1) bn.add_failure("Task(a=1)", "Task", {"a": 1}, "error 1", []) bn.add_failure("Task(a=2)", "Task", {"a": 2}, "error 2", []) bn.add_failure("TaskB(a=1)", "TaskB", {"a": 1}, "error", []) 
bn.send_email() self.check_email_send("Luigi: 3 failures in the last 60 minutes", "- Task (2 failures)\n\n error 2\n\n- TaskB (1 failure)\n\n error") def test_include_two_expls(self): bn = BatchNotifier(batch_mode="family", error_messages=2) bn.add_failure("Task(a=1)", "Task", {"a": 1}, "error 1", []) bn.add_failure("Task(a=2)", "Task", {"a": 2}, "error 2", []) bn.add_failure("TaskB(a=1)", "TaskB", {"a": 1}, "error", []) bn.send_email() self.check_email_send( "Luigi: 3 failures in the last 60 minutes", "- Task (2 failures)\n\n error 1\n\n error 2\n\n- TaskB (1 failure)\n\n error" ) def test_limit_expl_length(self): bn = BatchNotifier(batch_mode="family", error_messages=1, error_lines=2) bn.add_failure("Task(a=1)", "Task", {"a": "1"}, "line 1\nline 2\nline 3\nline 4\n", []) bn.send_email() self.check_email_send("Luigi: 1 failure in the last 60 minutes", "- Task (1 failure)\n\n line 3\n line 4") def test_expl_varies_by_owner(self): bn = BatchNotifier(batch_mode="family", error_messages=1) bn.add_failure("Task(a=1)", "Task", {"a": "1"}, "msg1", owners=["a@test.com"]) bn.add_failure("Task(a=2)", "Task", {"a": "2"}, "msg2", owners=["b@test.com"]) bn.send_email() send_calls = [ mock.call( "Luigi: Your tasks have 1 failure in the last 60 minutes", "- Task (1 failure)\n\n msg1", "sender@test.com", ("a@test.com",), ), mock.call( "Luigi: Your tasks have 1 failure in the last 60 minutes", "- Task (1 failure)\n\n msg2", "sender@test.com", ("b@test.com",), ), mock.call( "Luigi: 2 failures in the last 60 minutes", "- Task (2 failures)\n\n msg2", "sender@test.com", ("r@test.com",), ), ] self.send_email.assert_has_calls(send_calls, any_order=True) def test_include_two_expls_html_format(self): self.email().format = "html" bn = BatchNotifier(batch_mode="family", error_messages=2) bn.add_failure("Task(a=1)", "Task", {"a": 1}, "error 1", []) bn.add_failure("Task(a=2)", "Task", {"a": 2}, "error 2", []) bn.add_failure("TaskB(a=1)", "TaskB", {"a": 1}, "error", []) bn.send_email() 
self.check_email_send( "Luigi: 3 failures in the last 60 minutes", "
      \n
    • Task (2 failures)\n
      error 1
      \n
      error 2
      \n
    • TaskB (1 failure)\n
      error
      \n
    ", ) def test_limit_expl_length_html_format(self): self.email().format = "html" bn = BatchNotifier(batch_mode="family", error_messages=1, error_lines=2) bn.add_failure("Task(a=1)", "Task", {"a": 1}, "line 1\nline 2\nline 3\nline 4\n", []) bn.send_email() self.check_email_send("Luigi: 1 failure in the last 60 minutes", "
      \n
    • Task (1 failure)\n
      line 3\nline 4
      \n
    ") def test_send_clears_backlog(self): bn = BatchNotifier(batch_mode="all") bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", []) bn.add_disable("Task(a=5)", "Task", {"a": 5}, []) bn.add_scheduling_fail("Task(a=6)", "Task", {"a": 6}, "scheduling error", []) bn.send_email() self.send_email.reset_mock() bn.send_email() self.send_email.assert_not_called() def test_email_gets_cleared_on_failure(self): bn = BatchNotifier(batch_mode="all") bn.add_failure("Task(a=5)", "Task", {"a": 1}, "", []) self.send_email.side_effect = SMTPServerDisconnected("timeout") self.assertRaises(SMTPServerDisconnected, bn.send_email) self.send_email.reset_mock() bn.send_email() self.send_email.assert_not_called() def test_send_clears_all_old_data(self): bn = BatchNotifier(batch_mode="all", error_messages=100) for i in range(100): bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error {}".format(i), []) bn.add_disable("Task(a=5)", "Task", {"a": 5}, []) bn.add_scheduling_fail("Task(a=6)", "Task", {"a": 6}, "scheduling error {}".format(i), []) bn.send_email() self.check_email_send( "Luigi: 1 failure, 1 disable, 1 scheduling failure in the last 60 minutes", "- Task(a=5) (1 failure, 1 disable)\n\n error {}\n\n- Task(a=6) (1 scheduling failure)\n\n scheduling error {}".format(i, i), ) self.send_email.reset_mock() def test_auto_send_on_update_after_time_period(self): bn = BatchNotifier(batch_mode="all") bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", []) for i in range(60): bn.update() self.send_email.assert_not_called() self.incr_time(minutes=1) bn.update() self.check_email_send("Luigi: 1 failure in the last 60 minutes", "- Task(a=5) (1 failure)") def test_auto_send_on_update_after_time_period_with_disable_only(self): bn = BatchNotifier(batch_mode="all") bn.add_disable("Task(a=5)", "Task", {"a": 5}, []) for i in range(60): bn.update() self.send_email.assert_not_called() self.incr_time(minutes=1) bn.update() self.check_email_send("Luigi: 1 disable in the last 60 minutes", "- Task(a=5) 
(1 disable)") def test_no_auto_send_until_end_of_interval_with_error(self): bn = BatchNotifier(batch_mode="all") for i in range(90): bn.update() self.send_email.assert_not_called() self.incr_time(minutes=1) bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", []) for i in range(30): bn.update() self.send_email.assert_not_called() self.incr_time(minutes=1) bn.update() self.check_email_send("Luigi: 1 failure in the last 60 minutes", "- Task(a=5) (1 failure)") def test_no_auto_send_for_interval_after_exception(self): bn = BatchNotifier(batch_mode="all") bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", []) self.send_email.side_effect = SMTPServerDisconnected self.incr_time(minutes=60) self.assertRaises(SMTPServerDisconnected, bn.update) self.send_email.reset_mock() self.send_email.side_effect = None bn.add_failure("Task(a=5)", "Task", {"a": 5}, "error", []) for i in range(60): bn.update() self.send_email.assert_not_called() self.incr_time(minutes=1) bn.update() self.assertEqual(1, self.send_email.call_count) def test_send_batch_failure_emails_to_owners(self): bn = BatchNotifier(batch_mode="all") bn.add_failure("Task(a=1)", "Task", {"a": "1"}, "error", ["a@test.com", "b@test.com"]) bn.add_failure("Task(a=1)", "Task", {"a": "1"}, "error", ["b@test.com"]) bn.add_failure("Task(a=2)", "Task", {"a": "2"}, "error", ["a@test.com"]) bn.send_email() send_calls = [ mock.call( "Luigi: 3 failures in the last 60 minutes", "- Task(a=1) (2 failures)\n- Task(a=2) (1 failure)", "sender@test.com", ("r@test.com",), ), mock.call( "Luigi: Your tasks have 2 failures in the last 60 minutes", "- Task(a=1) (1 failure)\n- Task(a=2) (1 failure)", "sender@test.com", ("a@test.com",), ), mock.call( "Luigi: Your tasks have 2 failures in the last 60 minutes", "- Task(a=1) (2 failures)", "sender@test.com", ("b@test.com",), ), ] self.send_email.assert_has_calls(send_calls, any_order=True) def test_send_batch_disable_email_to_owners(self): bn = BatchNotifier(batch_mode="all") 
bn.add_disable("Task(a=1)", "Task", {"a": "1"}, ["a@test.com"]) bn.send_email() send_calls = [ mock.call( "Luigi: 1 disable in the last 60 minutes", "- Task(a=1) (1 disable)", "sender@test.com", ("r@test.com",), ), mock.call( "Luigi: Your tasks have 1 disable in the last 60 minutes", "- Task(a=1) (1 disable)", "sender@test.com", ("a@test.com",), ), ] self.send_email.assert_has_calls(send_calls, any_order=True) def test_batch_identical_expls(self): bn = BatchNotifier(error_messages=1, group_by_error_messages=True) bn.add_failure("Task(a=1)", "Task", {"a": "1"}, "msg1", []) bn.add_failure("Task(a=2)", "Task", {"a": "2"}, "msg1", []) bn.add_failure("Task(a=3)", "Task", {"a": "3"}, "msg1", []) bn.add_failure("Task(a=4)", "Task", {"a": "4"}, "msg2", []) bn.add_failure("Task(a=4)", "Task", {"a": "4"}, "msg2", []) bn.send_email() self.check_email_send( "Luigi: 5 failures in the last 60 minutes", "- Task(a=1) (1 failure)\n Task(a=2) (1 failure)\n Task(a=3) (1 failure)\n\n msg1\n\n- Task(a=4) (2 failures)\n\n msg2", ) def test_batch_identical_expls_html(self): self.email().format = "html" bn = BatchNotifier(error_messages=1, group_by_error_messages=True) bn.add_failure("Task(a=1)", "Task", {"a": "1"}, "msg1", []) bn.add_failure("Task(a=2)", "Task", {"a": "2"}, "msg1", []) bn.add_failure("Task(a=3)", "Task", {"a": "3"}, "msg1", []) bn.add_failure("Task(a=4)", "Task", {"a": "4"}, "msg2", []) bn.add_failure("Task(a=4)", "Task", {"a": "4"}, "msg2", []) bn.send_email() self.check_email_send( "Luigi: 5 failures in the last 60 minutes", "
      \n" "
    • Task(a=1) (1 failure)\n" "
      Task(a=2) (1 failure)\n" "
      Task(a=3) (1 failure)\n" "
      msg1
      \n" "
    • Task(a=4) (2 failures)\n" "
      msg2
      \n" "
    ", ) def test_unicode_error_message(self): bn = BatchNotifier(error_messages=1) bn.add_failure("Task()", "Task", {}, "Érror", []) bn.send_email() self.check_email_send("Luigi: 1 failure in the last 60 minutes", "- Task() (1 failure)\n\n Érror") def test_unicode_error_message_html(self): self.email().format = "html" bn = BatchNotifier(error_messages=1) bn.add_failure("Task()", "Task", {}, "Érror", []) bn.send_email() self.check_email_send("Luigi: 1 failure in the last 60 minutes", "
      \n
    • Task() (1 failure)\n
      Érror
      \n
    ") def test_unicode_param_value(self): for batch_mode in ("all", "unbatched_params"): self.send_email.reset_mock() bn = BatchNotifier(batch_mode=batch_mode) bn.add_failure("Task(a=á)", "Task", {"a": "á"}, "error", []) bn.send_email() self.check_email_send("Luigi: 1 failure in the last 60 minutes", "- Task(a=á) (1 failure)") def test_unicode_param_value_html(self): self.email().format = "html" for batch_mode in ("all", "unbatched_params"): self.send_email.reset_mock() bn = BatchNotifier(batch_mode=batch_mode) bn.add_failure("Task(a=á)", "Task", {"a": "á"}, "error", []) bn.send_email() self.check_email_send("Luigi: 1 failure in the last 60 minutes", "
      \n
    • Task(a=á) (1 failure)\n
    ") def test_unicode_param_name(self): for batch_mode in ("all", "unbatched_params"): self.send_email.reset_mock() bn = BatchNotifier(batch_mode=batch_mode) bn.add_failure("Task(á=a)", "Task", {"á": "a"}, "error", []) bn.send_email() self.check_email_send("Luigi: 1 failure in the last 60 minutes", "- Task(á=a) (1 failure)") def test_unicode_param_name_html(self): self.email().format = "html" for batch_mode in ("all", "unbatched_params"): self.send_email.reset_mock() bn = BatchNotifier(batch_mode=batch_mode) bn.add_failure("Task(á=a)", "Task", {"á": "a"}, "error", []) bn.send_email() self.check_email_send("Luigi: 1 failure in the last 60 minutes", "
      \n
    • Task(á=a) (1 failure)\n
    ") def test_unicode_class_name(self): bn = BatchNotifier() bn.add_failure("Tásk()", "Tásk", {}, "error", []) bn.send_email() self.check_email_send("Luigi: 1 failure in the last 60 minutes", "- Tásk() (1 failure)") def test_unicode_class_name_html(self): self.email().format = "html" bn = BatchNotifier() bn.add_failure("Tásk()", "Tásk", {}, "error", []) bn.send_email() self.check_email_send("Luigi: 1 failure in the last 60 minutes", "
      \n
    • Tásk() (1 failure)\n
    ") ================================================ FILE: test/choice_parameter_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from helpers import unittest import luigi class ChoiceParameterTest(unittest.TestCase): def test_parse_str(self): d = luigi.ChoiceParameter(choices=["1", "2", "3"]) self.assertEqual("3", d.parse("3")) def test_parse_int(self): d = luigi.ChoiceParameter(var_type=int, choices=[1, 2, 3]) self.assertEqual(3, d.parse(3)) def test_parse_int_conv(self): d = luigi.ChoiceParameter(var_type=int, choices=[1, 2, 3]) self.assertEqual(3, d.parse("3")) def test_invalid_choice(self): d = luigi.ChoiceParameter(choices=["1", "2", "3"]) self.assertRaises(ValueError, lambda: d.parse("xyz")) def test_invalid_choice_type(self): self.assertRaises(AssertionError, lambda: luigi.ChoiceParameter(var_type=int, choices=[1, 2, "3"])) def test_choices_parameter_exception(self): self.assertRaises(luigi.parameter.ParameterException, lambda: luigi.ChoiceParameter(var_type=int)) def test_hash_str(self): class Foo(luigi.Task): args = luigi.ChoiceParameter(var_type=str, choices=["1", "2", "3"]) p = luigi.ChoiceParameter(var_type=str, choices=["3", "2", "1"]) self.assertEqual(hash(Foo(args="3").args), hash(p.parse("3"))) def test_serialize_parse(self): a = luigi.ChoiceParameter(var_type=str, choices=["1", "2", "3"]) b = "3" self.assertEqual(b, a.parse(a.serialize(b))) def 
test_invalid_choice_task(self): class Foo(luigi.Task): args = luigi.ChoiceParameter(var_type=str, choices=["1", "2", "3"]) self.assertRaises(ValueError, lambda: Foo(args="4")) ================================================ FILE: test/clone_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from helpers import unittest import luigi import luigi.notifications luigi.notifications.DEBUG = True class LinearSum(luigi.Task): lo = luigi.IntParameter() hi = luigi.IntParameter() def requires(self): if self.hi > self.lo: return self.clone(hi=self.hi - 1) def run(self): if self.hi > self.lo: self.s = self.requires().s + self.f(self.hi - 1) else: self.s = 0 self.complete = lambda: True # workaround since we don't write any output def complete(self): return False def f(self, x): return x class PowerSum(LinearSum): p = luigi.IntParameter() def f(self, x): return x**self.p class CloneTest(unittest.TestCase): def test_args(self): t = LinearSum(lo=42, hi=45) self.assertEqual(t.param_args, (42, 45)) self.assertEqual(t.param_kwargs, {"lo": 42, "hi": 45}) def test_recursion(self): t = LinearSum(lo=42, hi=45) luigi.build([t], local_scheduler=True) self.assertEqual(t.s, 42 + 43 + 44) def test_inheritance(self): t = PowerSum(lo=42, hi=45, p=2) luigi.build([t], local_scheduler=True) self.assertEqual(t.s, 42**2 + 43**2 + 44**2) def test_inheritance_from_non_parameter(self): """ Cloning can 
pull non-source-parameters from source to target parameter. """ class SubTask(luigi.Task): lo = 1 @property def hi(self): return 2 t1 = SubTask() t2 = t1.clone(cls=LinearSum) self.assertEqual(t2.lo, 1) self.assertEqual(t2.hi, 2) ================================================ FILE: test/cmdline_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import os import subprocess import mock from helpers import unittest import luigi import luigi.cmdline from luigi.configuration import LuigiTomlParser, get_config from luigi.mock import MockTarget from luigi.setup_logging import DaemonLogging, InterfaceLogging class SomeTask(luigi.Task): n = luigi.IntParameter() def output(self): return MockTarget("/tmp/test_%d" % self.n) def run(self): f = self.output().open("w") f.write("done") f.close() class AmbiguousClass(luigi.Task): pass class AmbiguousClass(luigi.Task): # NOQA pass class TaskWithSameName(luigi.Task): def run(self): self.x = 42 class TaskWithSameName(luigi.Task): # NOQA # there should be no ambiguity def run(self): self.x = 43 class WriteToFile(luigi.Task): filename = luigi.Parameter() def output(self): return luigi.LocalTarget(self.filename) def run(self): f = self.output().open("w") print("foo", file=f) f.close() class FooBaseClass(luigi.Task): x = luigi.Parameter(default="foo_base_default") class FooSubClass(FooBaseClass): pass class ATaskThatFails(luigi.Task): def 
run(self): raise ValueError() class RequiredConfig(luigi.Config): required_test_param = luigi.Parameter() class TaskThatRequiresConfig(luigi.WrapperTask): def requires(self): if RequiredConfig().required_test_param == "A": return SubTaskThatFails() class SubTaskThatFails(luigi.Task): def complete(self): return False def run(self): raise Exception() class CmdlineTest(unittest.TestCase): def setUp(self): MockTarget.fs.clear() DaemonLogging._configured = False def tearDown(self): DaemonLogging._configured = False DaemonLogging.config = get_config() InterfaceLogging.config = get_config() def _clean_config(self): DaemonLogging.config = LuigiTomlParser() DaemonLogging.config.data = {} def _restore_config(self): DaemonLogging.config = LuigiTomlParser.instance() @mock.patch("logging.getLogger") def test_cmdline_main_task_cls(self, logger): luigi.run(["--local-scheduler", "--no-lock", "--n", "100"], main_task_cls=SomeTask) self.assertEqual(dict(MockTarget.fs.get_all_data()), {"/tmp/test_100": b"done"}) @mock.patch("logging.getLogger") def test_cmdline_local_scheduler(self, logger): luigi.run(["SomeTask", "--no-lock", "--n", "101"], local_scheduler=True) self.assertEqual(dict(MockTarget.fs.get_all_data()), {"/tmp/test_101": b"done"}) @mock.patch("logging.getLogger") def test_cmdline_other_task(self, logger): luigi.run(["--local-scheduler", "--no-lock", "SomeTask", "--n", "1000"]) self.assertEqual(dict(MockTarget.fs.get_all_data()), {"/tmp/test_1000": b"done"}) @mock.patch("logging.getLogger") def test_cmdline_ambiguous_class(self, logger): self.assertRaises(Exception, luigi.run, ["--local-scheduler", "--no-lock", "AmbiguousClass"]) @mock.patch("logging.getLogger") @mock.patch("logging.StreamHandler") def test_setup_interface_logging(self, handler, logger): opts = type("opts", (), {}) opts.background = False opts.logdir = False opts.logging_conf_file = None opts.log_level = "INFO" handler.return_value = mock.Mock(name="stream_handler") InterfaceLogging._configured = False 
InterfaceLogging.config = LuigiTomlParser() InterfaceLogging.config.data = {} InterfaceLogging.setup(opts) self.assertEqual([mock.call(handler.return_value)], logger.return_value.addHandler.call_args_list) InterfaceLogging._configured = False opts.logging_conf_file = "/blah" with self.assertRaises(OSError): InterfaceLogging.setup(opts) InterfaceLogging._configured = False @mock.patch("argparse.ArgumentParser.print_usage") def test_non_existent_class(self, print_usage): self.assertRaises(luigi.task_register.TaskClassNotFoundException, luigi.run, ["--local-scheduler", "--no-lock", "XYZ"]) @mock.patch("argparse.ArgumentParser.print_usage") def test_no_task(self, print_usage): self.assertRaises(SystemExit, luigi.run, ["--local-scheduler", "--no-lock"]) def test_luigid_logging_conf(self): with mock.patch("luigi.server.run") as server_run, mock.patch("logging.config.fileConfig") as fileConfig: luigi.cmdline.luigid([]) self.assertTrue(server_run.called) # the default test configuration specifies a logging conf file fileConfig.assert_called_with("test/testconfig/logging.cfg") def test_luigid_no_logging_conf(self): with mock.patch("luigi.server.run") as server_run, mock.patch("logging.basicConfig") as basicConfig: self._clean_config() DaemonLogging.config.data = { "core": { "no_configure_logging": False, "logging_conf_file": None, } } luigi.cmdline.luigid([]) self.assertTrue(server_run.called) self.assertTrue(basicConfig.called) def test_luigid_missing_logging_conf(self): with mock.patch("luigi.server.run") as server_run, mock.patch("logging.basicConfig") as basicConfig: self._restore_config() DaemonLogging.config.data = { "core": { "no_configure_logging": False, "logging_conf_file": "nonexistent.cfg", } } self.assertRaises(Exception, luigi.cmdline.luigid, []) self.assertFalse(server_run.called) self.assertFalse(basicConfig.called) class InvokeOverCmdlineTest(unittest.TestCase): def _run_cmdline(self, args): env = os.environ.copy() env["PYTHONPATH"] = env.get("PYTHONPATH", 
"") + ":.:test" print("Running: " + " ".join(args)) # To simplify rerunning failing tests p = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) stdout, stderr = p.communicate() # Unfortunately subprocess.check_output is 2.7+ return p.returncode, stdout, stderr def test_bin_luigi(self): t = luigi.LocalTarget(is_tmp=True) args = ["./bin/luigi", "--module", "cmdline_test", "WriteToFile", "--filename", t.path, "--local-scheduler", "--no-lock"] self._run_cmdline(args) self.assertTrue(t.exists()) def test_direct_python(self): t = luigi.LocalTarget(is_tmp=True) args = ["python", "test/cmdline_test.py", "WriteToFile", "--filename", t.path, "--local-scheduler", "--no-lock"] self._run_cmdline(args) self.assertTrue(t.exists()) def test_python_module(self): t = luigi.LocalTarget(is_tmp=True) args = ["python", "-m", "luigi", "--module", "cmdline_test", "WriteToFile", "--filename", t.path, "--local-scheduler", "--no-lock"] self._run_cmdline(args) self.assertTrue(t.exists()) def test_direct_python_help(self): returncode, stdout, stderr = self._run_cmdline(["python", "test/cmdline_test.py", "--help-all"]) self.assertTrue(stdout.find(b"--FooBaseClass-x") != -1) self.assertFalse(stdout.find(b"--x") != -1) def test_direct_python_help_class(self): returncode, stdout, stderr = self._run_cmdline(["python", "test/cmdline_test.py", "FooBaseClass", "--help"]) self.assertTrue(stdout.find(b"--FooBaseClass-x") != -1) self.assertTrue(stdout.find(b"--x") != -1) def test_bin_luigi_help(self): returncode, stdout, stderr = self._run_cmdline(["./bin/luigi", "--module", "cmdline_test", "--help-all"]) self.assertTrue(stdout.find(b"--FooBaseClass-x") != -1) self.assertFalse(stdout.find(b"--x") != -1) def test_python_module_luigi_help(self): returncode, stdout, stderr = self._run_cmdline(["python", "-m", "luigi", "--module", "cmdline_test", "--help-all"]) self.assertTrue(stdout.find(b"--FooBaseClass-x") != -1) self.assertFalse(stdout.find(b"--x") != -1) def 
test_bin_luigi_help_no_module(self): returncode, stdout, stderr = self._run_cmdline(["./bin/luigi", "--help"]) self.assertTrue(stdout.find(b"usage:") != -1) def test_bin_luigi_help_not_spammy(self): """ Test that `luigi --help` fits on one screen """ returncode, stdout, stderr = self._run_cmdline(["./bin/luigi", "--help"]) self.assertLessEqual(len(stdout.splitlines()), 15) def test_bin_luigi_all_help_spammy(self): """ Test that `luigi --help-all` doesn't fit on a screen Naturally, I don't mind this test breaking, but it convinces me that the "not spammy" test is actually testing what it claims too. """ returncode, stdout, stderr = self._run_cmdline(["./bin/luigi", "--help-all"]) self.assertGreater(len(stdout.splitlines()), 15) def test_error_mesage_on_misspelled_task(self): returncode, stdout, stderr = self._run_cmdline(["./bin/luigi", "RangeDaili"]) self.assertTrue(stderr.find(b"RangeDaily") != -1) def test_bin_luigi_no_parameters(self): returncode, stdout, stderr = self._run_cmdline(["./bin/luigi"]) self.assertTrue(stderr.find(b"No task specified") != -1) def test_python_module_luigi_no_parameters(self): returncode, stdout, stderr = self._run_cmdline(["python", "-m", "luigi"]) self.assertTrue(stderr.find(b"No task specified") != -1) def test_bin_luigi_help_class(self): returncode, stdout, stderr = self._run_cmdline(["./bin/luigi", "--module", "cmdline_test", "FooBaseClass", "--help"]) self.assertTrue(stdout.find(b"--FooBaseClass-x") != -1) self.assertTrue(stdout.find(b"--x") != -1) def test_python_module_help_class(self): returncode, stdout, stderr = self._run_cmdline(["python", "-m", "luigi", "--module", "cmdline_test", "FooBaseClass", "--help"]) self.assertTrue(stdout.find(b"--FooBaseClass-x") != -1) self.assertTrue(stdout.find(b"--x") != -1) def test_bin_luigi_options_before_task(self): args = ["./bin/luigi", "--module", "cmdline_test", "--no-lock", "--local-scheduler", "--FooBaseClass-x", "hello", "FooBaseClass"] returncode, stdout, stderr = 
self._run_cmdline(args) self.assertEqual(0, returncode) def test_bin_fail_on_unrecognized_args(self): returncode, stdout, stderr = self._run_cmdline(["./bin/luigi", "--no-lock", "--local-scheduler", "Task", "--unknown-param", "hiiii"]) self.assertNotEqual(0, returncode) def test_deps_py_script(self): """ Test the deps.py script. """ args = "python luigi/tools/deps.py --module examples.top_artists ArtistToplistToDatabase --date-interval 2015-W10".split() returncode, stdout, stderr = self._run_cmdline(args) self.assertEqual(0, returncode) self.assertTrue(stdout.find(b"[FileSystem] data/streams_2015_03_04_faked.tsv") != -1) self.assertTrue(stdout.find(b"[DB] localhost") != -1) def test_deps_tree_py_script(self): """ Test the deps_tree.py script. """ args = "python luigi/tools/deps_tree.py --module examples.top_artists AggregateArtists --date-interval 2012-06".split() returncode, stdout, stderr = self._run_cmdline(args) self.assertEqual(0, returncode) for i in range(1, 30): self.assertTrue(stdout.find(("-[Streams-{{'date': '2012-06-{0}'}}".format(str(i).zfill(2))).encode("utf-8")) != -1) def test_bin_mentions_misspelled_task(self): """ Test that the error message is informative when a task is misspelled. In particular it should say that the task is misspelled and not that the local parameters do not exist. """ returncode, stdout, stderr = self._run_cmdline(["./bin/luigi", "--module", "cmdline_test", "HooBaseClass", "--x 5"]) self.assertTrue(stderr.find(b"FooBaseClass") != -1) self.assertTrue(stderr.find(b"--x") != 0) def test_stack_trace_has_no_inner(self): """ Test that the stack trace for failing tasks are short The stack trace shouldn't contain unreasonably much implementation details of luigi In particular it should say that the task is misspelled and not that the local parameters do not exist. 
""" returncode, stdout, stderr = self._run_cmdline(["./bin/luigi", "--module", "cmdline_test", "ATaskThatFails", "--local-scheduler", "--no-lock"]) print(stdout) self.assertFalse(stdout.find(b"run() got an unexpected keyword argument 'tracking_url_callback'") != -1) self.assertFalse(stdout.find(b"During handling of the above exception, another exception occurred") != -1) def test_cmd_line_params_are_available_for_execution_summary(self): """ Test that config parameters specified on the command line are available while generating the execution summary. """ returncode, stdout, stderr = self._run_cmdline( [ "./bin/luigi", "--module", "cmdline_test", "TaskThatRequiresConfig", "--local-scheduler", "--no-lock--RequiredConfig-required-test-param", "A", ] ) print(stdout) print(stderr) self.assertNotEqual(returncode, 1) self.assertFalse(b"required_test_param" in stderr) if __name__ == "__main__": # Needed for one of the tests luigi.run() ================================================ FILE: test/config_env_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2018 Vote inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import os from helpers import LuigiTestCase, with_config from luigi.configuration import LuigiConfigParser, LuigiTomlParser, get_config from luigi.configuration.cfg_parser import InterpolationMissingEnvvarError class ConfigParserTest(LuigiTestCase): environ = { "TESTVAR": "1", } def setUp(self): self.environ_backup = {os.environ[key] for key in self.environ if key in os.environ} for key, value in self.environ.items(): os.environ[key] = value LuigiConfigParser._instance = None super(ConfigParserTest, self).setUp() def tearDown(self): for key in self.environ: os.environ.pop(key) for key, value in self.environ_backup: os.environ[key] = value if "LUIGI_CONFIG_PARSER" in os.environ: del os.environ["LUIGI_CONFIG_PARSER"] @with_config( { "test": { "a": "testval", "b": "%(a)s", "c": "%(a)s%(a)s", } } ) def test_basic_interpolation(self): # Make sure the default ConfigParser behaviour is not broken config = get_config() self.assertEqual(config.get("test", "b"), config.get("test", "a")) self.assertEqual(config.get("test", "c"), 2 * config.get("test", "a")) @with_config( { "test": { "a": "${TESTVAR}", "b": "${TESTVAR} ${TESTVAR}", "c": "${TESTVAR} %(a)s", "d": "${NONEXISTING}", } } ) def test_env_interpolation(self): config = get_config() self.assertEqual(config.get("test", "a"), "1") self.assertEqual(config.getint("test", "a"), 1) self.assertEqual(config.getboolean("test", "a"), True) self.assertEqual(config.get("test", "b"), "1 1") self.assertEqual(config.get("test", "c"), "1 1") with self.assertRaises(InterpolationMissingEnvvarError): config.get("test", "d") @with_config( { "test": { "foo-bar": "fob", "baz_qux": "bax", } } ) def test_underscore_vs_dash_style(self): config = get_config() self.assertEqual(config.get("test", "foo-bar"), "fob") self.assertEqual(config.get("test", "foo_bar"), "fob") self.assertEqual(config.get("test", "baz-qux"), "bax") self.assertEqual(config.get("test", "baz_qux"), "bax") @with_config( { "test": { "foo-bar": "fob", "foo_bar": "bax", } } ) 
def test_underscore_vs_dash_style_priority(self): config = get_config() self.assertEqual(config.get("test", "foo-bar"), "bax") self.assertEqual(config.get("test", "foo_bar"), "bax") def test_default_parser(self): config = get_config() self.assertIsInstance(config, LuigiConfigParser) os.environ["LUIGI_CONFIG_PARSER"] = "toml" config = get_config() self.assertIsInstance(config, LuigiTomlParser) ================================================ FILE: test/config_toml_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2018 Vote inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from helpers import LuigiTestCase from luigi.configuration import LuigiTomlParser, add_config_path, get_config class TomlConfigParserTest(LuigiTestCase): @classmethod def setUpClass(cls): add_config_path("test/testconfig/luigi.toml") add_config_path("test/testconfig/luigi_local.toml") def setUp(self): LuigiTomlParser._instance = None super(TomlConfigParserTest, self).setUp() def test_get_config(self): config = get_config("toml") self.assertIsInstance(config, LuigiTomlParser) def test_file_reading(self): config = get_config("toml") self.assertIn("hdfs", config.data) def test_get(self): config = get_config("toml") # test getting self.assertEqual(config.get("hdfs", "client"), "hadoopcli") self.assertEqual(config.get("hdfs", "client", "test"), "hadoopcli") # test default self.assertEqual(config.get("hdfs", "test", "check"), "check") with self.assertRaises(KeyError): config.get("hdfs", "test") # test override self.assertEqual(config.get("hdfs", "namenode_host"), "localhost") # test non-string values self.assertEqual(config.get("hdfs", "namenode_port"), 50030) def test_set(self): config = get_config("toml") self.assertEqual(config.get("hdfs", "client"), "hadoopcli") config.set("hdfs", "client", "test") self.assertEqual(config.get("hdfs", "client"), "test") config.set("hdfs", "check", "test me") self.assertEqual(config.get("hdfs", "check"), "test me") def test_has_option(self): config = get_config("toml") self.assertTrue(config.has_option("hdfs", "client")) self.assertFalse(config.has_option("hdfs", "nope")) self.assertFalse(config.has_option("nope", "client")) class HelpersTest(LuigiTestCase): def test_add_without_install(self): enabled = LuigiTomlParser.enabled LuigiTomlParser.enabled = False with self.assertRaises(ImportError): add_config_path("test/testconfig/luigi.toml") LuigiTomlParser.enabled = enabled def test_get_without_install(self): enabled = LuigiTomlParser.enabled LuigiTomlParser.enabled = False with self.assertRaises(ImportError): get_config("toml") 
LuigiTomlParser.enabled = enabled ================================================ FILE: test/conftest.py ================================================ from typing import List import pytest import luigi.task_register @pytest.fixture(autouse=True) def reset_luigi_registry(): """Reset the Luigi task registry before and after each test. Prevents registry pollution between tests when running with pytest-xdist, where multiple tests execute sequentially within the same worker process. This mirrors the behaviour of LuigiTestCase.setUp/tearDown and applies it to all tests automatically, including those that inherit unittest.TestCase directly without going through LuigiTestCase. """ original = luigi.task_register.Register._get_reg() luigi.task_register.Register.clear_instance_cache() yield luigi.task_register.Register._set_reg(original) luigi.task_register.Register.clear_instance_cache() def pytest_collection_modifyitems(items: List[pytest.Item]) -> None: """ Automatically add the equivalent of pytest.mark.unmarked to any test which has no markers For example, enables the ability to target "contrib + unmarked" tests (eventually getting rid of the generic "contrib" marker): - pytest test/contrib/ -m "contrib or unmarked" """ for item in items: # Check if the item has any markers (custom or builtin) if not any(item.iter_markers()): item.add_marker(pytest.mark.unmarked) ================================================ FILE: test/contrib/__init__.py ================================================ ================================================ FILE: test/contrib/_webhdfs_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import os import pytest from helpers import unittest from luigi.contrib import webhdfs @pytest.mark.apache class TestWebHdfsTarget(unittest.TestCase): """ This test requires a running Hadoop cluster with WebHdfs enabled This test requires the luigi.cfg file to have a `hdfs` section with the namenode_host, namenode_port and user settings. """ def setUp(self): self.testDir = "/tmp/luigi-test".format() self.path = os.path.join(self.testDir, "out.txt") self.client = webhdfs.WebHdfsClient() self.target = webhdfs.WebHdfsTarget(self.path) def tearDown(self): if self.client.exists(self.testDir): self.client.remove(self.testDir, recursive=True) def test_write(self): self.assertFalse(self.client.exists(self.path)) output = self.target.open("w") output.write("this is line 1\n") output.write("this is line #2\n") output.close() self.assertTrue(self.client.exists(self.path)) def test_read(self): self.test_write() input_ = self.target.open("r") all_test = "this is line 1\nthis is line #2\n" self.assertEqual(all_test, input_.read()) input_.close() def test_read_lines(self): self.test_write() input_ = self.target.open("r") lines = list(input_.readlines()) self.assertEqual(lines[0], "this is line 1") self.assertEqual(lines[1], "this is line #2") input_.close() ================================================ FILE: test/contrib/azureblob_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2018 Microsoft Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the 
License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Integration tests for azureblob module. """ import json import os import unittest import pytest import luigi from luigi.contrib.azureblob import AzureBlobClient, AzureBlobTarget from luigi.target import FileAlreadyExists account_name = os.environ.get("AZURITE_ACCOUNT_NAME") account_key = os.environ.get("AZURITE_ACCOUNT_KEY") sas_token = os.environ.get("AZURITE_SAS_TOKEN") custom_domain = os.environ.get("AZURITE_CUSTOM_DOMAIN") protocol = os.environ.get("AZURITE_PROTOCOL", "http") client = AzureBlobClient(account_name, account_key, sas_token, custom_domain=custom_domain, protocol=protocol) @pytest.mark.azureblob class AzureBlobClientTest(unittest.TestCase): def setUp(self): self.client = client def tearDown(self): pass def test_splitfilepath_blob_none(self): container, blob = self.client.splitfilepath("abc") self.assertEqual(container, "abc") self.assertIsNone(blob) def test_splitfilepath_blob_toplevel(self): container, blob = self.client.splitfilepath("abc/cde") self.assertEqual(container, "abc") self.assertEqual(blob, "cde") def test_splitfilepath_blob_nested(self): container, blob = self.client.splitfilepath("abc/cde/xyz.txt") self.assertEqual(container, "abc") self.assertEqual(blob, "cde/xyz.txt") def test_create_delete_container(self): import datetime import hashlib m = hashlib.new("md5", usedforsecurity=False) m.update(datetime.datetime.now().__str__().encode()) container_name = m.hexdigest() self.assertFalse(self.client.exists(container_name)) self.assertTrue(self.client.create_container(container_name)) 
self.assertTrue(self.client.exists(container_name)) self.client.delete_container(container_name) self.assertFalse(self.client.exists(container_name)) def test_upload_copy_move_remove_blob(self): import datetime import hashlib import tempfile m = hashlib.new("md5", usedforsecurity=False) m.update(datetime.datetime.now().__str__().encode()) container_name = m.hexdigest() m.update(datetime.datetime.now().__str__().encode()) from_blob_name = m.hexdigest() from_path = "{container_name}/{from_blob_name}".format(container_name=container_name, from_blob_name=from_blob_name) m.update(datetime.datetime.now().__str__().encode()) to_blob_name = m.hexdigest() to_path = "{container_name}/{to_blob_name}".format(container_name=container_name, to_blob_name=to_blob_name) message = datetime.datetime.now().__str__().encode() self.assertTrue(self.client.create_container(container_name)) with tempfile.NamedTemporaryFile() as f: f.write(message) f.flush() # upload self.client.upload(f.name, container_name, from_blob_name) self.assertTrue(self.client.exists(from_path)) # mkdir self.assertRaises(FileAlreadyExists, self.client.mkdir, from_path, False, True) # mkdir does not actually create anything self.client.mkdir(to_path, True, True) self.assertFalse(self.client.exists(to_path)) # copy self.assertIn(self.client.copy(from_path, to_path)["copy_status"], ["success", "pending"]) self.assertTrue(self.client.exists(to_path)) # remove self.assertTrue(self.client.remove(from_path)) self.assertFalse(self.client.exists(from_path)) # move back file self.client.move(to_path, from_path) self.assertTrue(self.client.exists(from_path)) self.assertFalse(self.client.exists(to_path)) self.assertTrue(self.client.remove(from_path)) self.assertFalse(self.client.exists(from_path)) # delete container self.client.delete_container(container_name) self.assertFalse(self.client.exists(container_name)) class MovieScriptTask(luigi.Task): def output(self): return AzureBlobTarget("luigi-test", "movie-cheesy.txt", 
client, download_when_reading=False) def run(self): client.create_container("luigi-test") with self.output().open("w") as op: op.write("I'm going to make him an offer he can't refuse.\n") op.write("Toto, I've got a feeling we're not in Kansas anymore.\n") op.write("May the Force be with you.\n") op.write("Bond. James Bond.\n") op.write("Greed, for lack of a better word, is good.\n") class AzureJsonDumpTask(luigi.Task): def output(self): return AzureBlobTarget("luigi-test", "stats.json", client) def run(self): with self.output().open("w") as op: json.dump([1, 2, 3], op) class FinalTask(luigi.Task): def requires(self): return {"movie": self.clone(MovieScriptTask), "np": self.clone(AzureJsonDumpTask)} def run(self): with self.input()["movie"].open("r") as movie, self.input()["np"].open("r") as np, self.output().open("w") as output: movie_lines = movie.read() assert "Toto, I've got a feeling" in movie_lines output.write(movie_lines) data = json.load(np) assert data == [1, 2, 3] output.write(data.__str__()) def output(self): return luigi.LocalTarget("samefile") @pytest.mark.azureblob class AzureBlobTargetTest(unittest.TestCase): def setUp(self): self.client = client def tearDown(self): pass def test_AzureBlobTarget(self): final_task = FinalTask() luigi.build([final_task], local_scheduler=True, log_level="NOTSET") output = final_task.output().open("r").read() assert "Toto" in output ================================================ FILE: test/contrib/batch_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2018 Outlier Bio, LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import pytest from helpers import skipOnTravisAndGithubActions, unittest import luigi.contrib.batch as batch try: import boto3 client = boto3.client("batch") except ImportError: raise unittest.SkipTest("boto3 is not installed. BatchTasks require boto3") class MockBotoBatchClient: def describe_job_queues(self): return {"jobQueues": [{"jobQueueName": "test_queue", "state": "ENABLED", "status": "VALID"}]} def list_jobs(self, jobQueue="", jobStatus=""): return {"jobSummaryList": [{"jobName": "test_job", "jobId": "abcd"}]} def describe_jobs(self, jobs=[]): return { "ResponseMetadata": {"HTTPStatusCode": 200}, "jobs": [{"status": "SUCCEEDED", "attempts": [{"container": {"logStreamName": "test_job_abcd_log_stream"}}]}], } def submit_job(self, jobDefinition="", jobName="", jobQueue="", parameters={}): return {"jobId": "abcd"} def register_job_definition(self, **kwargs): return {"ResponseMetadata": {"HTTPStatusCode": 200}} class MockBotoLogsClient: def get_log_events(self, logGroupName="", logStreamName="", startFromHead=True): return {"events": [{"message": "log line 1"}, {"message": "log line 2"}, {"message": "log line 3"}]} @pytest.mark.aws @skipOnTravisAndGithubActions("boto3 now importable. 
These tests need mocked") class BatchClientTest(unittest.TestCase): def setUp(self): self.bc = batch.BatchClient(poll_time=10) self.bc._client = MockBotoBatchClient() self.bc._log_client = MockBotoLogsClient() def test_get_active_queue(self): self.assertEqual(self.bc.get_active_queue(), "test_queue") def test_get_job_id_from_name(self): self.assertEqual(self.bc.get_job_id_from_name("test_job"), "abcd") def test_get_job_status(self): self.assertEqual(self.bc.get_job_status("abcd"), "SUCCEEDED") def test_get_logs(self): log_str = "log line 1\nlog line 2\nlog line 3" self.assertEqual(self.bc.get_logs("test_job_abcd_log_stream"), log_str) def test_submit_job(self): job_id = self.bc.submit_job("test_job_def", {"param1": "foo", "param2": "bar"}, job_name="test_job") self.assertEqual(job_id, "abcd") def test_submit_job_specific_queue(self): job_id = self.bc.submit_job("test_job_def", {"param1": "foo", "param2": "bar"}, job_name="test_job", queue="test_queue") self.assertEqual(job_id, "abcd") def test_submit_job_non_existant_queue(self): with self.assertRaises(Exception): self.bc.submit_job("test_job_def", {"param1": "foo", "param2": "bar"}, job_name="test_job", queue="non_existant_queue") def test_wait_on_job(self): job_id = self.bc.submit_job("test_job_def", {"param1": "foo", "param2": "bar"}, job_name="test_job") self.assertTrue(self.bc.wait_on_job(job_id)) def test_wait_on_job_failed(self): job_id = self.bc.submit_job("test_job_def", {"param1": "foo", "param2": "bar"}, job_name="test_job") self.bc.get_job_status = lambda x: "FAILED" with self.assertRaises(batch.BatchJobException) as context: self.bc.wait_on_job(job_id) self.assertTrue("log line 1" in context.exception) @pytest.mark.aws @skipOnTravisAndGithubActions("boto3 now importable. 
These tests need mocked") class BatchTaskTest(unittest.TestCase): def setUp(self): self.task = batch.BatchTask(job_definition="test_job_def", job_name="test_job", poll_time=10) ================================================ FILE: test/contrib/beam_dataflow_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2019 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import json import unittest import mock import pytest from mock import MagicMock, patch import luigi from luigi import local_target from luigi.contrib import beam_dataflow, bigquery, gcs class TestDataflowParamKeys(beam_dataflow.DataflowParamKeys): runner = "runner" project = "project" zone = "zone" region = "region" staging_location = "stagingLocation" temp_location = "tempLocation" gcp_temp_location = "gcpTempLocation" num_workers = "numWorkers" autoscaling_algorithm = "autoscalingAlgorithm" max_num_workers = "maxNumWorkers" disk_size_gb = "diskSizeGb" worker_machine_type = "workerMachineType" worker_disk_type = "workerDiskType" job_name = "jobName" service_account = "serviceAccount" network = "network" subnetwork = "subnetwork" labels = "labels" class TestRequires(luigi.ExternalTask): def output(self): return luigi.LocalTarget(path="some-input-dir") class SimpleTestTask(beam_dataflow.BeamDataflowJobTask): dataflow_params = TestDataflowParamKeys() def requires(self): return TestRequires() def output(self): return local_target.LocalTarget(path="some-output.txt") def 
dataflow_executable(self): return ["java", "com.spotify.luigi.SomeJobClass"] class FullTestTask(beam_dataflow.BeamDataflowJobTask): project = "some-project" runner = "DirectRunner" temp_location = "some-temp" staging_location = "some-staging" gcp_temp_location = "some-gcp-temp" num_workers = 1 autoscaling_algorithm = "THROUGHPUT_BASED" max_num_workers = 2 network = "some-network" subnetwork = "some-subnetwork" disk_size_gb = 5 worker_machine_type = "n1-standard-4" job_name = "SomeJobName" worker_disk_type = "compute.googleapis.com/projects//zones//diskTypes/pd-ssd" service_account = "some-service-account@google.com" zone = "europe-west1-c" region = "europe-west1" labels = {"k1": "v1"} dataflow_params = TestDataflowParamKeys() def requires(self): return TestRequires() def output(self): return {"output": luigi.LocalTarget(path="some-output.txt")} def args(self): return ["--extraArg=present"] def dataflow_executable(self): return ["java", "com.spotify.luigi.SomeJobClass"] class FilePatternsTestTask(beam_dataflow.BeamDataflowJobTask): dataflow_params = TestDataflowParamKeys() def requires(self): return {"input1": TestRequires(), "input2": TestRequires()} def file_pattern(self): return {"input2": "*.some-ext"} def output(self): return {"output": luigi.LocalTarget(path="some-output.txt")} def dataflow_executable(self): return ["java", "com.spotify.luigi.SomeJobClass"] class DummyCmdLineTestTask(beam_dataflow.BeamDataflowJobTask): dataflow_params = TestDataflowParamKeys() def dataflow_executable(self): pass def requires(self): return {} def output(self): return {} def _mk_cmd_line(self): return ["echo", '"hello world"'] @pytest.mark.gcloud class BeamDataflowTest(unittest.TestCase): def test_dataflow_simple_cmd_line_args(self): task = SimpleTestTask() task.runner = "DirectRunner" expected = ["java", "com.spotify.luigi.SomeJobClass", "--runner=DirectRunner", "--input=some-input-dir/part-*", "--output=some-output.txt"] self.assertEqual(task._mk_cmd_line(), expected) def 
test_dataflow_full_cmd_line_args(self): full_test_task = FullTestTask() cmd_line_args = full_test_task._mk_cmd_line() expected = [ "java", "com.spotify.luigi.SomeJobClass", "--runner=DirectRunner", "--project=some-project", "--zone=europe-west1-c", "--region=europe-west1", "--stagingLocation=some-staging", "--tempLocation=some-temp", "--gcpTempLocation=some-gcp-temp", "--numWorkers=1", "--autoscalingAlgorithm=THROUGHPUT_BASED", "--maxNumWorkers=2", "--diskSizeGb=5", "--workerMachineType=n1-standard-4", "--workerDiskType=compute.googleapis.com/projects//zones//diskTypes/pd-ssd", "--network=some-network", "--subnetwork=some-subnetwork", "--jobName=SomeJobName", "--serviceAccount=some-service-account@google.com", '--labels={"k1": "v1"}', "--extraArg=present", "--input=some-input-dir/part-*", "--output=some-output.txt", ] self.assertEqual(json.loads(cmd_line_args[19][9:]), {"k1": "v1"}) self.assertEqual(cmd_line_args, expected) def test_dataflow_with_file_patterns(self): cmd_line_args = FilePatternsTestTask()._mk_cmd_line() self.assertIn("--input1=some-input-dir/part-*", cmd_line_args) self.assertIn("--input2=some-input-dir/*.some-ext", cmd_line_args) def test_dataflow_with_invalid_file_patterns(self): task = FilePatternsTestTask() task.file_pattern = MagicMock(return_value="notadict") with self.assertRaises(ValueError): task._mk_cmd_line() def test_dataflow_input_arg_formatting(self): class TestTaskListOfTargetsInput(SimpleTestTask): class TestRequiresListOfTargets(luigi.ExternalTask): def output(self): return [luigi.LocalTarget(path="some-input-1"), luigi.LocalTarget(path="some-input-2")] def requires(self): return self.TestRequiresListOfTargets() task_list_input = TestTaskListOfTargetsInput() self.assertEqual(task_list_input._format_input_args(), ["--input=some-input-1/part-*,some-input-2/part-*"]) class TestTaskListOfTuplesInput(SimpleTestTask): class TestRequiresListOfTuples(luigi.ExternalTask): def output(self): return [("input1", 
luigi.LocalTarget(path="some-input-1")), ("input2", luigi.LocalTarget(path="some-input-2"))] def requires(self): return self.TestRequiresListOfTuples() task_list_tuples_input = TestTaskListOfTuplesInput() self.assertEqual(task_list_tuples_input._format_input_args(), ["--input1=some-input-1/part-*", "--input2=some-input-2/part-*"]) class TestTaskDictInput(SimpleTestTask): class TestRequiresDict(luigi.ExternalTask): def output(self): return {"input1": luigi.LocalTarget(path="some-input-1"), "input2": luigi.LocalTarget(path="some-input-2")} def requires(self): return self.TestRequiresDict() task_dict_input = TestTaskDictInput() self.assertEqual(task_dict_input._format_input_args(), ["--input1=some-input-1/part-*", "--input2=some-input-2/part-*"]) class TestTaskTupleInput(SimpleTestTask): class TestRequiresTuple(luigi.ExternalTask): def output(self): return "some-key", luigi.LocalTarget(path="some-input") def requires(self): return self.TestRequiresTuple() task_tuple_input = TestTaskTupleInput() self.assertEqual(task_tuple_input._format_input_args(), ["--some-key=some-input/part-*"]) def test_task_output_arg_completion(self): class TestCompleteTarget(luigi.Target): def exists(self): return True class TestIncompleteTarget(luigi.Target): def exists(self): return False class TestTaskDictOfCompleteOutput(SimpleTestTask): def output(self): return {"output": TestCompleteTarget()} self.assertEqual(TestTaskDictOfCompleteOutput().complete(), True) class TestTaskDictOfIncompleteOutput(SimpleTestTask): def output(self): return {"output": TestIncompleteTarget()} self.assertEqual(TestTaskDictOfIncompleteOutput().complete(), False) class TestTaskDictOfMixedCompleteOutput(SimpleTestTask): def output(self): return {"output1": TestIncompleteTarget(), "output2": TestCompleteTarget()} self.assertEqual(TestTaskDictOfMixedCompleteOutput().complete(), False) def test_get_target_path(self): bq_target = bigquery.BigQueryTarget("p", "d", "t", client="fake_client") 
self.assertEqual(SimpleTestTask.get_target_path(bq_target), "p:d.t") gcs_target = gcs.GCSTarget("gs://foo/bar.txt", client="fake_client") self.assertEqual(SimpleTestTask.get_target_path(gcs_target), "gs://foo/bar.txt") with self.assertRaises(ValueError): SimpleTestTask.get_target_path("not_a_target") def test_dataflow_runner_resolution(self): task = SimpleTestTask() # Test that supported runners are passed through for runner in ["DirectRunner", "DataflowRunner"]: task.runner = runner self.assertEqual(task._get_runner(), runner) # Test that unsupported runners throw an error task.runner = "UnsupportedRunner" with self.assertRaises(ValueError): task._get_runner() def test_dataflow_successful_run_callbacks(self): task = DummyCmdLineTestTask() task.before_run = MagicMock() task.validate_output = MagicMock() task.on_successful_run = MagicMock() task.on_successful_output_validation = MagicMock() task.cleanup_on_error = MagicMock() task.run() task.before_run.assert_called_once_with() task.validate_output.assert_called_once_with() task.cleanup_on_error.assert_not_called() task.on_successful_run.assert_called_once_with() task.on_successful_output_validation.assert_called_once_with() def test_dataflow_successful_run_invalid_output_callbacks(self): task = DummyCmdLineTestTask() task.before_run = MagicMock() task.validate_output = MagicMock(return_value=False) task.on_successful_run = MagicMock() task.on_successful_output_validation = MagicMock() task.cleanup_on_error = MagicMock() with self.assertRaises(ValueError): task.run() task.before_run.assert_called_once_with() task.validate_output.assert_called_once_with() task.cleanup_on_error.assert_called_once_with(mock.ANY) task.on_successful_run.assert_called_once_with() task.on_successful_output_validation.assert_not_called() @patch("luigi.contrib.beam_dataflow.subprocess.Popen.wait", return_value=1) @patch("luigi.contrib.beam_dataflow.os._exit", side_effect=OSError) def test_dataflow_failed_run_callbacks(self, popen, os_exit): 
task = DummyCmdLineTestTask() task.before_run = MagicMock() task.validate_output = MagicMock() task.on_successful_run = MagicMock() task.on_successful_output_validation = MagicMock() task.cleanup_on_error = MagicMock() with self.assertRaises(OSError): task.run() task.before_run.assert_called_once_with() task.validate_output.assert_not_called() task.cleanup_on_error.assert_called_once_with(mock.ANY) task.on_successful_run.assert_not_called() task.on_successful_output_validation.assert_not_called() ================================================ FILE: test/contrib/bigquery_avro_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2019 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ These are the unit tests for the BigQueryLoadAvro class. 
""" import unittest import avro import avro.schema import pytest from luigi.contrib.bigquery_avro import BigQueryLoadAvro @pytest.mark.gcloud class BigQueryAvroTest(unittest.TestCase): def test_writer_schema_method_existence(self): schema_json = """ { "namespace": "example.avro", "type": "record", "name": "User", "fields": [ {"name": "name", "type": "string"}, {"name": "favorite_number", "type": ["int", "null"]}, {"name": "favorite_color", "type": ["string", "null"]} ] } """ avro_schema = avro.schema.Parse(schema_json) reader = avro.io.DatumReader(avro_schema, avro_schema) actual_schema = BigQueryLoadAvro._get_writer_schema(reader) self.assertEqual(actual_schema, avro_schema, "writer(s) avro_schema attribute not found") # otherwise AttributeError is thrown ================================================ FILE: test/contrib/bigquery_gcloud_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015 Twitter Inc # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ This is an integration test for the BigQuery-luigi binding. This test requires credentials that can access GCS & access to a bucket below. Follow the directions in the gcloud tools to set up local credentials. 
""" import json import os import unittest import luigi try: import google.auth import googleapiclient.errors except ImportError: raise unittest.SkipTest("Unable to load googleapiclient module") import avro.schema import pytest from avro.datafile import DataFileWriter from avro.io import DatumWriter from helpers import unittest from luigi.contrib import bigquery, bigquery_avro, gcs from luigi.contrib.bigquery import BigQueryExecutionError from luigi.contrib.gcs import GCSTarget # In order to run this test, you should set your GCS/BigQuery project/bucket. # Unfortunately there's no mock PROJECT_ID = os.environ.get("GCS_TEST_PROJECT_ID", "your_project_id_here") BUCKET_NAME = os.environ.get("GCS_TEST_BUCKET", "your_test_bucket_here") TEST_FOLDER = os.environ.get("TRAVIS_BUILD_ID", "bigquery_test_folder") DATASET_ID = os.environ.get("BQ_TEST_DATASET_ID", "luigi_tests") EU_DATASET_ID = os.environ.get("BQ_TEST_EU_DATASET_ID", "luigi_tests_eu") EU_LOCATION = "EU" US_LOCATION = "US" CREDENTIALS, _ = google.auth.default() def bucket_url(suffix): """ Actually it's bucket + test folder name """ return "gs://{}/{}/{}".format(BUCKET_NAME, TEST_FOLDER, suffix) @pytest.mark.gcloud class TestLoadTask(bigquery.BigQueryLoadTask): source = luigi.Parameter() table = luigi.Parameter() dataset = luigi.Parameter() location = luigi.Parameter(default=None) @property def schema(self): return [ {"mode": "NULLABLE", "name": "field1", "type": "STRING"}, {"mode": "NULLABLE", "name": "field2", "type": "INTEGER"}, ] def source_uris(self): return [self.source] def output(self): return bigquery.BigQueryTarget(PROJECT_ID, self.dataset, self.table, location=self.location) @pytest.mark.gcloud class TestRunQueryTask(bigquery.BigQueryRunQueryTask): query = """ SELECT 'hello' as field1, 2 as field2 """ table = luigi.Parameter() dataset = luigi.Parameter() def output(self): return bigquery.BigQueryTarget(PROJECT_ID, self.dataset, self.table) @pytest.mark.gcloud class 
TestExtractTask(bigquery.BigQueryExtractTask): source = luigi.Parameter() table = luigi.Parameter() dataset = luigi.Parameter() location = luigi.Parameter(default=None) extract_gcs_file = luigi.Parameter() destination_format = luigi.Parameter(default=bigquery.DestinationFormat.CSV) print_header = luigi.Parameter(default=bigquery.PrintHeader.TRUE) field_delimiter = luigi.Parameter(default=bigquery.FieldDelimiter.COMMA) def output(self): return GCSTarget(bucket_url(self.extract_gcs_file)) def requires(self): return TestLoadTask(source=self.source, dataset=self.dataset, table=self.table) @pytest.mark.gcloud class BigQueryGcloudTest(unittest.TestCase): def setUp(self): self.bq_client = bigquery.BigQueryClient(CREDENTIALS) self.gcs_client = gcs.GCSClient(CREDENTIALS) # Setup GCS input data try: self.gcs_client.client.buckets().insert(project=PROJECT_ID, body={"name": BUCKET_NAME, "location": EU_LOCATION}).execute() except googleapiclient.errors.HttpError as ex: # todo verify that existing dataset is not US if ex.resp.status != 409: # bucket already exists raise self.gcs_client.remove(bucket_url(""), recursive=True) self.gcs_client.mkdir(bucket_url("")) text = "\n".join(map(json.dumps, [{"field1": "hi", "field2": 1}, {"field1": "bye", "field2": 2}])) self.gcs_file = bucket_url(self.id()) self.gcs_client.put_string(text, self.gcs_file) # Setup BigQuery datasets self.table = bigquery.BQTable(project_id=PROJECT_ID, dataset_id=DATASET_ID, table_id=self.id().split(".")[-1], location=None) self.table_eu = bigquery.BQTable(project_id=PROJECT_ID, dataset_id=EU_DATASET_ID, table_id=self.id().split(".")[-1] + "_eu", location=EU_LOCATION) self.addCleanup(self.gcs_client.remove, bucket_url(""), recursive=True) self.addCleanup(self.bq_client.delete_dataset, self.table.dataset) self.addCleanup(self.bq_client.delete_dataset, self.table_eu.dataset) self.bq_client.delete_dataset(self.table.dataset) self.bq_client.delete_dataset(self.table_eu.dataset) 
self.bq_client.make_dataset(self.table.dataset, body={}) self.bq_client.make_dataset(self.table_eu.dataset, body={}) def test_extract_to_gcs_csv(self): task1 = TestLoadTask(source=self.gcs_file, dataset=self.table.dataset.dataset_id, table=self.table.table_id) task1.run() task2 = TestExtractTask( source=self.gcs_file, dataset=self.table.dataset.dataset_id, table=self.table.table_id, extract_gcs_file=self.id() + "_extract_file", destination_format=bigquery.DestinationFormat.CSV, ) task2.run() self.assertTrue(task2.output().exists) def test_extract_to_gcs_csv_alternate(self): task1 = TestLoadTask(source=self.gcs_file, dataset=self.table.dataset.dataset_id, table=self.table.table_id) task1.run() task2 = TestExtractTask( source=self.gcs_file, dataset=self.table.dataset.dataset_id, table=self.table.table_id, extract_gcs_file=self.id() + "_extract_file", destination_format=bigquery.DestinationFormat.CSV, print_header=bigquery.PrintHeader.FALSE, field_delimiter=bigquery.FieldDelimiter.PIPE, ) task2.run() self.assertTrue(task2.output().exists) def test_extract_to_gcs_json(self): task1 = TestLoadTask(source=self.gcs_file, dataset=self.table.dataset.dataset_id, table=self.table.table_id) task1.run() task2 = TestExtractTask( source=self.gcs_file, dataset=self.table.dataset.dataset_id, table=self.table.table_id, extract_gcs_file=self.id() + "_extract_file", destination_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON, ) task2.run() self.assertTrue(task2.output().exists) def test_extract_to_gcs_avro(self): task1 = TestLoadTask(source=self.gcs_file, dataset=self.table.dataset.dataset_id, table=self.table.table_id) task1.run() task2 = TestExtractTask( source=self.gcs_file, dataset=self.table.dataset.dataset_id, table=self.table.table_id, extract_gcs_file=self.id() + "_extract_file", destination_format=bigquery.DestinationFormat.AVRO, ) task2.run() self.assertTrue(task2.output().exists) def test_load_eu_to_undefined(self): task = TestLoadTask(source=self.gcs_file, 
dataset=self.table.dataset.dataset_id, table=self.table.table_id, location=EU_LOCATION) self.assertRaises(Exception, task.run) def test_load_us_to_eu(self): task = TestLoadTask(source=self.gcs_file, dataset=self.table_eu.dataset.dataset_id, table=self.table_eu.table_id, location=US_LOCATION) self.assertRaises(Exception, task.run) def test_load_eu_to_eu(self): task = TestLoadTask(source=self.gcs_file, dataset=self.table_eu.dataset.dataset_id, table=self.table_eu.table_id, location=EU_LOCATION) task.run() self.assertTrue(self.bq_client.dataset_exists(self.table_eu)) self.assertTrue(self.bq_client.table_exists(self.table_eu)) self.assertIn(self.table_eu.dataset_id, list(self.bq_client.list_datasets(self.table_eu.project_id))) self.assertIn(self.table_eu.table_id, list(self.bq_client.list_tables(self.table_eu.dataset))) def test_load_undefined_to_eu(self): task = TestLoadTask(source=self.gcs_file, dataset=self.table_eu.dataset.dataset_id, table=self.table_eu.table_id) task.run() self.assertTrue(self.bq_client.dataset_exists(self.table_eu)) self.assertTrue(self.bq_client.table_exists(self.table_eu)) self.assertIn(self.table_eu.dataset_id, list(self.bq_client.list_datasets(self.table_eu.project_id))) self.assertIn(self.table_eu.table_id, list(self.bq_client.list_tables(self.table_eu.dataset))) def test_load_new_eu_dataset(self): self.bq_client.delete_dataset(self.table.dataset) self.bq_client.delete_dataset(self.table_eu.dataset) self.assertFalse(self.bq_client.dataset_exists(self.table_eu)) task = TestLoadTask(source=self.gcs_file, dataset=self.table_eu.dataset.dataset_id, table=self.table_eu.table_id, location=EU_LOCATION) task.run() self.assertTrue(self.bq_client.dataset_exists(self.table_eu)) self.assertTrue(self.bq_client.table_exists(self.table_eu)) self.assertIn(self.table_eu.dataset_id, list(self.bq_client.list_datasets(self.table_eu.project_id))) self.assertIn(self.table_eu.table_id, list(self.bq_client.list_tables(self.table_eu.dataset))) def test_copy(self): 
task = TestLoadTask(source=self.gcs_file, dataset=self.table.dataset.dataset_id, table=self.table.table_id) task.run() self.assertTrue(self.bq_client.dataset_exists(self.table)) self.assertTrue(self.bq_client.table_exists(self.table)) self.assertIn(self.table.dataset_id, list(self.bq_client.list_datasets(self.table.project_id))) self.assertIn(self.table.table_id, list(self.bq_client.list_tables(self.table.dataset))) new_table = self.table._replace(table_id=self.table.table_id + "_copy") self.bq_client.copy(source_table=self.table, dest_table=new_table) self.assertTrue(self.bq_client.table_exists(new_table)) self.bq_client.delete_table(new_table) self.assertFalse(self.bq_client.table_exists(new_table)) def test_table_uri(self): intended_uri = "bq://" + PROJECT_ID + "/" + DATASET_ID + "/" + self.table.table_id self.assertTrue(self.table.uri == intended_uri) def test_run_query(self): task = TestRunQueryTask(table=self.table.table_id, dataset=self.table.dataset.dataset_id) task._BIGQUERY_CLIENT = self.bq_client task.run() self.assertTrue(self.bq_client.table_exists(self.table)) def test_run_successful_job(self): body = {"configuration": {"query": {"query": "select count(*) from unnest([1,2,3])"}}} job_id = self.bq_client.run_job(PROJECT_ID, body) self.assertIsNotNone(job_id) self.assertNotEqual("", job_id) def test_run_failing_job(self): body = {"configuration": {"query": {"query": "this is not a valid query"}}} self.assertRaises(BigQueryExecutionError, lambda: self.bq_client.run_job(PROJECT_ID, body)) @pytest.mark.gcloud class BigQueryLoadAvroTest(unittest.TestCase): def _produce_test_input(self): schema = avro.schema.parse(""" { "type":"record", "name":"TrackEntity2", "namespace":"com.spotify.entity.schema", "doc":"Track entity merged from various sources", "fields":[ { "name":"map_record", "type":{ "type":"map", "values":{ "type":"record", "name":"MapNestedRecordObj", "doc":"Nested Record in a map doc", "fields":[ { "name":"element1", "type":"string", "doc":"element 
1 doc" }, { "name":"element2", "type":[ "null", "string" ], "doc":"element 2 doc" } ] } }, "doc":"doc for map" }, { "name":"additional", "type":{ "type":"map", "values":"string" }, "doc":"doc for second map record" }, { "name":"track_gid", "type":"string", "doc":"Track GID in hexadecimal string" }, { "name":"track_uri", "type":"string", "doc":"Track URI in base62 string" }, { "name":"Suit", "type":{ "type":"enum", "name":"Suit", "doc":"enum documentation broz", "symbols":[ "SPADES", "HEARTS", "DIAMONDS", "CLUBS" ] } }, { "name":"FakeRecord", "type":{ "type":"record", "name":"FakeRecord", "namespace":"com.spotify.data.types.coolType", "doc":"My Fake Record doc", "fields":[ { "name":"coolName", "type":"string", "doc":"Cool Name doc" } ] } }, { "name":"master_metadata", "type":[ "null", { "type":"record", "name":"MasterMetadata", "namespace":"com.spotify.data.types.metadata", "doc":"metadoc", "fields":[ { "name":"track", "type":[ "null", { "type":"record", "name":"Track", "doc":"Sqoop import of track", "fields":[ { "name":"id", "type":[ "null", "int" ], "doc":"id description field", "default":null, "columnName":"id", "sqlType":"4" }, { "name":"name", "type":[ "null", "string" ], "doc":"name description field", "default":null, "columnName":"name", "sqlType":"12" } ], "tableName":"track" } ], "default":null } ] } ] }, { "name":"children", "type":{ "type":"array", "items":{ "type":"record", "name":"Child", "doc":"array of children documentation", "fields":[ { "name":"name", "type":"string", "doc":"my specific child\'s doc" } ] } } } ] }""") self.addCleanup(os.remove, "tmp.avro") writer = DataFileWriter(open("tmp.avro", "wb"), DatumWriter(), schema) writer.append( { "track_gid": "Cool guid", "map_record": {"Cool key": {"element1": "element 1 data", "element2": "element 2 data"}}, "additional": {"key1": "value1"}, "master_metadata": {"track": {"id": 1, "name": "Cool Track Name"}}, "track_uri": "Totally a url here", "FakeRecord": {"coolName": "Cool Fake Record Name"}, 
"Suit": "DIAMONDS", "children": [{"name": "Bob"}, {"name": "Joe"}], } ) writer.close() self.gcs_client.put("tmp.avro", self.gcs_dir_url + "/tmp.avro") def setUp(self): self.gcs_client = gcs.GCSClient(CREDENTIALS) self.bq_client = bigquery.BigQueryClient(CREDENTIALS) self.table_id = "avro_bq_table" self.gcs_dir_url = "gs://" + BUCKET_NAME + "/foo" self.addCleanup(self.gcs_client.remove, self.gcs_dir_url) self.addCleanup(self.bq_client.delete_dataset, bigquery.BQDataset(PROJECT_ID, DATASET_ID, EU_LOCATION)) self._produce_test_input() def test_load_avro_dir_and_propagate_doc(self): class BigQueryLoadAvroTestInput(luigi.ExternalTask): def output(_): return gcs.GCSTarget(self.gcs_dir_url) class BigQueryLoadAvroTestTask(bigquery_avro.BigQueryLoadAvro): def requires(_): return BigQueryLoadAvroTestInput() def output(_): return bigquery.BigQueryTarget(PROJECT_ID, DATASET_ID, self.table_id, location=EU_LOCATION) task = BigQueryLoadAvroTestTask() self.assertFalse(task.complete()) task.run() self.assertTrue(task.complete()) table = self.bq_client.client.tables().get(projectId=PROJECT_ID, datasetId=DATASET_ID, tableId=self.table_id).execute() self.assertEqual(table["description"], "Track entity merged from various sources") # First map self.assertEqual(table["schema"]["fields"][0]["description"], "doc for map") # key self.assertFalse("description" in table["schema"]["fields"][0]["fields"][0]) # Value self.assertEqual(table["schema"]["fields"][0]["fields"][1]["description"], "Nested Record in a map doc") # Value record data self.assertEqual(table["schema"]["fields"][0]["fields"][1]["fields"][0]["description"], "element 1 doc") self.assertEqual(table["schema"]["fields"][0]["fields"][1]["fields"][1]["description"], "element 2 doc") # Second map self.assertEqual(table["schema"]["fields"][1]["description"], "doc for second map record") # key self.assertFalse("description" in table["schema"]["fields"][1]["fields"][0]) # Value self.assertFalse("description" in 
table["schema"]["fields"][1]["fields"][1]) # Several top level Primitive and Enums self.assertEqual(table["schema"]["fields"][2]["description"], "Track GID in hexadecimal string") self.assertEqual(table["schema"]["fields"][3]["description"], "Track URI in base62 string") self.assertEqual(table["schema"]["fields"][4]["description"], "enum documentation broz") # Nested Record containing primitive self.assertEqual(table["schema"]["fields"][5]["description"], "My Fake Record doc") self.assertEqual(table["schema"]["fields"][5]["fields"][0]["description"], "Cool Name doc") # Union with internal Record self.assertEqual(table["schema"]["fields"][6]["description"], "metadoc") self.assertEqual(table["schema"]["fields"][6]["fields"][0]["description"], "Sqoop import of track") self.assertEqual(table["schema"]["fields"][6]["fields"][0]["fields"][0]["description"], "id description field") self.assertEqual(table["schema"]["fields"][6]["fields"][0]["fields"][1]["description"], "name description field") # Array of Primitive self.assertEqual(table["schema"]["fields"][7]["description"], "array of children documentation") self.assertEqual(table["schema"]["fields"][7]["fields"][0]["description"], "my specific child's doc") ================================================ FILE: test/contrib/bigquery_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2019 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """ These are the unit tests for the BigQueryLoadAvro class. """ import unittest import mock import pytest from mock.mock import MagicMock from luigi.contrib import bigquery from luigi.contrib.bigquery import BigQueryClient, BigQueryExtractTask, BigQueryLoadTask, BigQueryRunQueryTask, BigQueryTarget, BQDataset from luigi.contrib.gcs import GCSTarget @pytest.mark.gcloud class BigQueryLoadTaskTest(unittest.TestCase): @mock.patch("luigi.contrib.bigquery.BigQueryClient.run_job") def test_configure_job(self, run_job): class MyBigQueryLoadTask(BigQueryLoadTask): def source_uris(self): return ["gs://_"] def configure_job(self, configuration): configuration["load"]["destinationTableProperties"] = {"description": "Nice table"} return configuration def output(self): return BigQueryTarget(project_id="proj", dataset_id="ds", table_id="t") job = MyBigQueryLoadTask() job.run() expected_body = { "configuration": { "load": { "destinationTable": {"projectId": "proj", "datasetId": "ds", "tableId": "t"}, "encoding": "UTF-8", "sourceFormat": "NEWLINE_DELIMITED_JSON", "writeDisposition": "WRITE_EMPTY", "sourceUris": ["gs://_"], "maxBadRecords": 0, "ignoreUnknownValues": False, "autodetect": True, "destinationTableProperties": {"description": "Nice table"}, } } } run_job.assert_called_with("proj", expected_body, dataset=BQDataset("proj", "ds", None)) @pytest.mark.gcloud class BigQueryRunQueryTaskTest(unittest.TestCase): @mock.patch("luigi.contrib.bigquery.BigQueryClient.run_job") def test_configure_job(self, run_job): class MyBigQueryRunQuery(BigQueryRunQueryTask): query = "SELECT @thing" use_legacy_sql = False def configure_job(self, configuration): configuration["query"]["parameterMode"] = "NAMED" configuration["query"]["queryParameters"] = {"name": "thing", "parameterType": {"type": "STRING"}, "parameterValue": {"value": "Nice Thing"}} return configuration def output(self): return BigQueryTarget(project_id="proj", dataset_id="ds", table_id="t") job = MyBigQueryRunQuery() job.run() 
expected_body = { "configuration": { "query": { "query": "SELECT @thing", "priority": "INTERACTIVE", "destinationTable": {"projectId": "proj", "datasetId": "ds", "tableId": "t"}, "allowLargeResults": True, "createDisposition": "CREATE_IF_NEEDED", "writeDisposition": "WRITE_TRUNCATE", "flattenResults": True, "userDefinedFunctionResources": [], "useLegacySql": False, "parameterMode": "NAMED", "queryParameters": {"name": "thing", "parameterType": {"type": "STRING"}, "parameterValue": {"value": "Nice Thing"}}, } } } run_job.assert_called_with("proj", expected_body, dataset=BQDataset("proj", "ds", None)) @pytest.mark.gcloud class BigQueryExtractTaskTest(unittest.TestCase): @mock.patch("luigi.contrib.bigquery.BigQueryClient.run_job") def test_configure_job(self, run_job): class MyBigQueryExtractTask(BigQueryExtractTask): destination_format = "AVRO" def configure_job(self, configuration): configuration["extract"]["useAvroLogicalTypes"] = True return configuration def input(self): return BigQueryTarget(project_id="proj", dataset_id="ds", table_id="t") def output(self): return GCSTarget("gs://_") job = MyBigQueryExtractTask() job.run() expected_body = { "configuration": { "extract": { "sourceTable": {"projectId": "proj", "datasetId": "ds", "tableId": "t"}, "destinationUris": ["gs://_"], "destinationFormat": "AVRO", "compression": "NONE", "useAvroLogicalTypes": True, } } } run_job.assert_called_with("proj", expected_body, dataset=BQDataset("proj", "ds", None)) @pytest.mark.gcloud class BigQueryClientTest(unittest.TestCase): def test_retry_succeeds_on_second_attempt(self): try: from googleapiclient import errors except ImportError: raise unittest.SkipTest("Unable to load googleapiclient module") client = MagicMock(spec=BigQueryClient) attempts = 0 @bigquery.bq_retry def fail_once(bq_client): nonlocal attempts attempts += 1 if attempts == 1: raise errors.HttpError( resp=MagicMock(status=500), content=b'{"error": {"message": "stub"}', ) else: return MagicMock(status=200) 
response = fail_once(client) client._initialise_client.assert_called_once() self.assertEqual(attempts, 2) self.assertEqual(response.status, 200) ================================================ FILE: test/contrib/cascading_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import pytest from helpers import unittest import luigi.target from luigi.contrib.target import CascadingClient @pytest.mark.contrib class CascadingClientTest(unittest.TestCase): def setUp(self): class FirstClient: def exists(self, pos_arg, kw_arg="first"): if pos_arg < 10: return pos_arg elif pos_arg < 20: return kw_arg elif kw_arg == "raise_fae": raise luigi.target.FileAlreadyExists("oh noes!") else: raise Exception() class SecondClient: def exists(self, pos_arg, other_kw_arg="second", kw_arg="for-backwards-compatibility"): if pos_arg < 30: return -pos_arg elif pos_arg < 40: return other_kw_arg else: raise Exception() self.clients = [FirstClient(), SecondClient()] self.client = CascadingClient(self.clients) def test_successes(self): self.assertEqual(5, self.client.exists(5)) self.assertEqual("yay", self.client.exists(15, kw_arg="yay")) def test_fallbacking(self): self.assertEqual(-25, self.client.exists(25)) self.assertEqual("lol", self.client.exists(35, kw_arg="yay", other_kw_arg="lol")) # Note: the first method don't accept the other keyword argument self.assertEqual(-15, self.client.exists(15, 
kw_arg="yay", other_kw_arg="lol")) def test_failings(self): self.assertRaises(Exception, lambda: self.client.exists(45)) self.assertRaises(AttributeError, lambda: self.client.mkdir()) def test_FileAlreadyExists_propagation(self): self.assertRaises(luigi.target.FileAlreadyExists, lambda: self.client.exists(25, kw_arg="raise_fae")) def test_method_names_kwarg(self): self.client = CascadingClient(self.clients, method_names=[]) self.assertRaises(AttributeError, lambda: self.client.exists()) self.client = CascadingClient(self.clients, method_names=["exists"]) self.assertEqual(5, self.client.exists(5)) ================================================ FILE: test/contrib/datadog_metric_test.py ================================================ # -*- coding: utf-8 -*- import time import mock from helpers import unittest from luigi.contrib.datadog_metric import DatadogMetricsCollector from luigi.metrics import MetricsCollectors from luigi.scheduler import Scheduler WORKER = "myworker" class DatadogMetricTest(unittest.TestCase): def setUp(self): self.mockDatadog() self.time = time.time self.collector = DatadogMetricsCollector() self.s = Scheduler(metrics_collector=MetricsCollectors.datadog) def tearDown(self): self.unMockDatadog() if time.time != self.time: time.time = self.time def startTask(self, scheduler=None): if scheduler: s = scheduler else: s = self.s s.add_task(worker=WORKER, task_id="DDTaskID", family="DDTaskName") task = s._state.get_task("DDTaskID") task.time_running = 0 return task def mockDatadog(self): self.create_patcher = mock.patch("datadog.api.Event.create") self.mock_create = self.create_patcher.start() self.increment_patcher = mock.patch("datadog.statsd.increment") self.mock_increment = self.increment_patcher.start() self.gauge_patcher = mock.patch("datadog.statsd.gauge") self.mock_gauge = self.gauge_patcher.start() def unMockDatadog(self): self.create_patcher.stop() self.increment_patcher.stop() self.gauge_patcher.stop() def setTime(self, t): time.time = 
lambda: t def test_send_event_on_task_started(self): task = self.startTask() self.collector.handle_task_started(task) self.mock_create.assert_called_once_with( alert_type="info", priority="low", tags=["task_name:DDTaskName", "task_state:STARTED", "environment:development", "application:luigi"], text="A task has been started in the pipeline named: DDTaskName", title="Luigi: A task has been started!", ) def test_send_increment_on_task_started(self): task = self.startTask() self.collector.handle_task_started(task) self.mock_increment.assert_called_once_with("luigi.task.started", 1, tags=["task_name:DDTaskName", "environment:development", "application:luigi"]) def test_send_event_on_task_failed(self): task = self.startTask() self.collector.handle_task_failed(task) self.mock_create.assert_called_once_with( alert_type="error", priority="normal", tags=["task_name:DDTaskName", "task_state:FAILED", "environment:development", "application:luigi"], text="A task has failed in the pipeline named: DDTaskName", title="Luigi: A task has failed!", ) def test_send_increment_on_task_failed(self): task = self.startTask() self.collector.handle_task_failed(task) self.mock_increment.assert_called_once_with("luigi.task.failed", 1, tags=["task_name:DDTaskName", "environment:development", "application:luigi"]) def test_send_event_on_task_disabled(self): s = Scheduler(metrics_collector=MetricsCollectors.datadog, disable_persist=10, retry_count=2, disable_window=2) task = self.startTask(scheduler=s) self.collector.handle_task_disabled(task, s._config) self.mock_create.assert_called_once_with( alert_type="error", priority="normal", tags=["task_name:DDTaskName", "task_state:DISABLED", "environment:development", "application:luigi"], text="A task has been disabled in the pipeline named: DDTaskName. 
" + "The task has failed 2 times in the last 2 seconds" + ", so it is being disabled for 10 seconds.", title="Luigi: A task has been disabled!", ) def test_send_increment_on_task_disabled(self): task = self.startTask() self.collector.handle_task_disabled(task, self.s._config) self.mock_increment.assert_called_once_with("luigi.task.disabled", 1, tags=["task_name:DDTaskName", "environment:development", "application:luigi"]) def test_send_event_on_task_done(self): task = self.startTask() self.collector.handle_task_done(task) self.mock_create.assert_called_once_with( alert_type="info", priority="low", tags=["task_name:DDTaskName", "task_state:DONE", "environment:development", "application:luigi"], text="A task has completed in the pipeline named: DDTaskName", title="Luigi: A task has been completed!", ) def test_send_increment_on_task_done(self): task = self.startTask() self.collector.handle_task_done(task) self.mock_increment.assert_called_once_with("luigi.task.done", 1, tags=["task_name:DDTaskName", "environment:development", "application:luigi"]) def test_send_gauge_on_task_done(self): self.setTime(0) task = self.startTask() self.collector.handle_task_done(task) self.mock_gauge.assert_called_once_with("luigi.task.execution_time", 0, tags=["task_name:DDTaskName", "environment:development", "application:luigi"]) ================================================ FILE: test/contrib/dataproc_test.py ================================================ """This is an integration test for the Dataproc-luigi binding. This test requires credentials that can access GCS & access to a bucket below. Follow the directions in the gcloud tools to set up local credentials. 
""" import unittest try: import google.auth from googleapiclient import discovery from luigi.contrib import dataproc default_credentials, _ = google.auth.default() default_client = discovery.build("dataproc", "v1", cache_discovery=False, credentials=default_credentials) dataproc.set_dataproc_client(default_client) except ImportError: raise unittest.SkipTest("Unable to load google cloud dependencies") import os import time import pytest import luigi # In order to run this test, you should set these to your GCS project. # Unfortunately there's no mock PROJECT_ID = os.environ.get("DATAPROC_TEST_PROJECT_ID", "your_project_id_here") CLUSTER_NAME = os.environ.get("DATAPROC_TEST_CLUSTER", "unit-test-cluster") REGION = os.environ.get("DATAPROC_REGION", "global") IMAGE_VERSION = "1-0" class _DataprocBaseTestCase(unittest.TestCase): def setUp(self): pass def tearDown(self): pass @pytest.mark.gcloud class DataprocTaskTest(_DataprocBaseTestCase): def test_1_create_cluster(self): success = luigi.run( ["--local-scheduler", "--no-lock", "CreateDataprocClusterTask", "--gcloud-project-id=" + PROJECT_ID, "--dataproc-cluster-name=" + CLUSTER_NAME] ) self.assertTrue(success) def test_2_create_cluster_should_notice_existing_cluster_and_return_immediately(self): job_start = time.time() success = luigi.run( ["--local-scheduler", "--no-lock", "CreateDataprocClusterTask", "--gcloud-project-id=" + PROJECT_ID, "--dataproc-cluster-name=" + CLUSTER_NAME] ) self.assertTrue(success) self.assertLess(time.time() - job_start, 3) def test_3_submit_minimal_job(self): # The job itself will fail because the job files don't exist # We don't care, because then we would be testing spark # We care the job was submitted correctly, so that's what we test luigi.run( [ "--local-scheduler", "--no-lock", "DataprocSparkTask", "--gcloud-project-id=" + PROJECT_ID, "--dataproc-cluster-name=" + CLUSTER_NAME, "--main-class=my.MinimalMainClass", ] ) response = 
dataproc.get_dataproc_client().projects().regions().jobs().list(projectId=PROJECT_ID, region=REGION, clusterName=CLUSTER_NAME).execute() lastJob = response["jobs"][0]["sparkJob"] self.assertEqual(lastJob["mainClass"], "my.MinimalMainClass") def test_4_submit_spark_job(self): # The job itself will fail because the job files don't exist # We don't care, because then we would be testing spark # We care the job was submitted correctly, so that's what we test luigi.run( [ "--local-scheduler", "--no-lock", "DataprocSparkTask", "--gcloud-project-id=" + PROJECT_ID, "--dataproc-cluster-name=" + CLUSTER_NAME, "--main-class=my.MainClass", "--jars=one.jar,two.jar", "--job-args=foo,bar", ] ) response = dataproc.get_dataproc_client().projects().regions().jobs().list(projectId=PROJECT_ID, region=REGION, clusterName=CLUSTER_NAME).execute() lastJob = response["jobs"][0]["sparkJob"] self.assertEqual(lastJob["mainClass"], "my.MainClass") self.assertEqual(lastJob["jarFileUris"], ["one.jar", "two.jar"]) self.assertEqual(lastJob["args"], ["foo", "bar"]) def test_5_submit_pyspark_job(self): # The job itself will fail because the job files don't exist # We don't care, because then we would be testing pyspark # We care the job was submitted correctly, so that's what we test luigi.run( [ "--local-scheduler", "--no-lock", "DataprocPysparkTask", "--gcloud-project-id=" + PROJECT_ID, "--dataproc-cluster-name=" + CLUSTER_NAME, "--job-file=main_job.py", "--extra-files=extra1.py,extra2.py", "--job-args=foo,bar", ] ) response = dataproc.get_dataproc_client().projects().regions().jobs().list(projectId=PROJECT_ID, region=REGION, clusterName=CLUSTER_NAME).execute() lastJob = response["jobs"][0]["pysparkJob"] self.assertEqual(lastJob["mainPythonFileUri"], "main_job.py") self.assertEqual(lastJob["pythonFileUris"], ["extra1.py", "extra2.py"]) self.assertEqual(lastJob["args"], ["foo", "bar"]) def test_6_delete_cluster(self): success = luigi.run( ["--local-scheduler", "--no-lock", 
"DeleteDataprocClusterTask", "--gcloud-project-id=" + PROJECT_ID, "--dataproc-cluster-name=" + CLUSTER_NAME] ) self.assertTrue(success) def test_7_delete_cluster_should_return_immediately_if_no_cluster(self): job_start = time.time() success = luigi.run( ["--local-scheduler", "--no-lock", "DeleteDataprocClusterTask", "--gcloud-project-id=" + PROJECT_ID, "--dataproc-cluster-name=" + CLUSTER_NAME] ) self.assertTrue(success) self.assertLess(time.time() - job_start, 3) def test_8_create_cluster_image_version(self): success = luigi.run( [ "--local-scheduler", "--no-lock", "CreateDataprocClusterTask", "--gcloud-project-id=" + PROJECT_ID, "--dataproc-cluster-name=" + CLUSTER_NAME + "-" + IMAGE_VERSION, "--image-version=1.0", ] ) self.assertTrue(success) def test_9_delete_cluster_image_version(self): success = luigi.run( [ "--local-scheduler", "--no-lock", "DeleteDataprocClusterTask", "--gcloud-project-id=" + PROJECT_ID, "--dataproc-cluster-name=" + CLUSTER_NAME + "-" + IMAGE_VERSION, ] ) self.assertTrue(success) ================================================ FILE: test/contrib/docker_runner_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2017 Open Targets # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Tests for Docker container wrapper for Luigi. Requires: - docker: ``pip install docker`` Written and maintained by Andrea Pierleoni (@apierleoni). 
Contributions by Eliseo Papa (@elipapa) """ import logging import tempfile from tempfile import NamedTemporaryFile import pytest from helpers import unittest import luigi from luigi.contrib.docker_runner import DockerTask logger = logging.getLogger("luigi-interface") try: import docker from docker.errors import ContainerError, ImageNotFound client = docker.from_env() client.version() except ImportError: raise unittest.SkipTest("Unable to load docker module") except Exception: raise unittest.SkipTest("Unable to connect to docker daemon") tempfile.tempdir = "/tmp" # set it explicitly to make it work out of the box in mac os local_file = NamedTemporaryFile() local_file.write(b"this is a test file\n") local_file.flush() class SuccessJob(DockerTask): image = "busybox:latest" name = "SuccessJob" class FailJobImageNotFound(DockerTask): image = "image-does-not-exists" name = "FailJobImageNotFound" class FailJobContainer(DockerTask): image = "busybox" name = "FailJobContainer" command = "cat this-file-does-not-exist" class WriteToTmpDir(DockerTask): image = "busybox" name = "WriteToTmpDir" container_tmp_dir = "/tmp/luigi-test" command = "test -d /tmp/luigi-test" # command = 'test -d $LUIGI_TMP_DIR'# && echo ok >$LUIGI_TMP_DIR/test' class MountLocalFileAsVolume(DockerTask): image = "busybox" name = "MountLocalFileAsVolume" # volumes= {'/tmp/local_file_test': {'bind': local_file.name, 'mode': 'rw'}} binds = [local_file.name + ":/tmp/local_file_test"] command = "test -f /tmp/local_file_test" class MountLocalFileAsVolumeWithParam(DockerTask): dummyopt = luigi.Parameter() image = "busybox" name = "MountLocalFileAsVolumeWithParam" binds = [local_file.name + ":/tmp/local_file_test"] command = "test -f /tmp/local_file_test" class MountLocalFileAsVolumeWithParamRedefProperties(DockerTask): dummyopt = luigi.Parameter() image = "busybox" name = "MountLocalFileAsVolumeWithParamRedef" @property def binds(self): return [local_file.name + ":/tmp/local_file_test" + self.dummyopt] @property 
def command(self): return "test -f /tmp/local_file_test" + self.dummyopt def complete(self): return True class MultipleDockerTask(luigi.WrapperTask): """because the volumes property is defined as a list, spinning multiple containers led to conflict in the volume binds definition, with multiple host directories pointing to the same container directory""" def requires(self): return [MountLocalFileAsVolumeWithParam(dummyopt=opt) for opt in ["one", "two", "three"]] class MultipleDockerTaskRedefProperties(luigi.WrapperTask): def requires(self): return [MountLocalFileAsVolumeWithParamRedefProperties(dummyopt=opt) for opt in ["one", "two", "three"]] @pytest.mark.contrib class TestDockerTask(unittest.TestCase): # def tearDown(self): # local_file.close() def test_success_job(self): success = SuccessJob() luigi.build([success], local_scheduler=True) self.assertTrue(success) def test_temp_dir_creation(self): writedir = WriteToTmpDir() writedir.run() def test_local_file_mount(self): localfile = MountLocalFileAsVolume() localfile.run() def test_fail_job_image_not_found(self): fail = FailJobImageNotFound() self.assertRaises(ImageNotFound, fail.run) def test_fail_job_container(self): fail = FailJobContainer() self.assertRaises(ContainerError, fail.run) def test_multiple_jobs(self): worked = MultipleDockerTask() luigi.build([worked], local_scheduler=True) self.assertTrue(worked) def test_multiple_jobs2(self): worked = MultipleDockerTaskRedefProperties() luigi.build([worked], local_scheduler=True) self.assertTrue(worked) ================================================ FILE: test/contrib/dropbox_test.py ================================================ import bz2 import os import tempfile import unittest import uuid from datetime import datetime import pytest import luigi from luigi.format import NopFormat try: import dropbox import dropbox.exceptions from luigi.contrib.dropbox import DropboxClient except ImportError: raise unittest.SkipTest("DropboxTarget and DropboxClient will not 
be tested. Dropbox library is not installed") DROPBOX_APP_TOKEN = os.environ.get("DROPBOX_APP_TOKEN") if not DROPBOX_APP_TOKEN: raise ValueError( "In order to test DropboxTarget and DropboxClient, the DROPBOX_APP_TOKEN environment variable " "must contain a valid Dropbox OAuth2 Token. \n" "Get one at https://www.dropbox.com/developers/apps " ) DROPBOX_TEST_PATH = "/luigi-tests/luigi-test-" + datetime.now().strftime("%Y.%m.%d-%H.%M.%S") + str(uuid.uuid4()) # These paths will be created in the test set-up DROPBOX_TEST_SIMPLE_DIR = DROPBOX_TEST_PATH + "/dir2" DROPBOX_TEST_FILE_IN_DIR = DROPBOX_TEST_SIMPLE_DIR + "/test2.txt" DROPBOX_TEST_SIMPLE_FILE = DROPBOX_TEST_PATH + "/test.txt" DROPBOX_TEST_DIR_TO_DELETE = DROPBOX_TEST_PATH + "/dir_to_delete" DROPBOX_TEST_FILE_TO_DELETE_2 = DROPBOX_TEST_DIR_TO_DELETE + "/test3.2.txt" DROPBOX_TEST_FILE_TO_DELETE_1 = DROPBOX_TEST_DIR_TO_DELETE + "/test3.1.txt" DROPBOX_TEST_FILE_TO_COPY_ORIG = DROPBOX_TEST_PATH + "/dir4/test4.txt" DROPBOX_TEST_FILE_TO_MOVE_ORIG = DROPBOX_TEST_PATH + "/dir3/test3.txt" # All the following paths will be used by the tests DROPBOX_TEST_SMALL_FILE = DROPBOX_TEST_PATH + "/dir/small.txt" DROPBOX_TEST_LARGE_FILE = DROPBOX_TEST_PATH + "/dir/big.bin" DROPBOX_TEST_FILE_TO_COPY_DEST = DROPBOX_TEST_PATH + "/dir_four/test_four.txt" DROPBOX_TEST_FILE_TO_MOVE_DEST = DROPBOX_TEST_PATH + "/dir_three/test_three.txt" DROPBOX_TEST_OUTER_DIR_TO_CREATE = DROPBOX_TEST_PATH + "/new_folder" DROPBOX_TEST_DIR_TO_CREATE = DROPBOX_TEST_OUTER_DIR_TO_CREATE + "/inner_folder" DROPBOX_TEST_SIMPLE_DIR_TO_CREATE_AND_DELETE = DROPBOX_TEST_PATH + "/another_new_folder" DROPBOX_TEST_FILE_TO_UPLOAD_BZIP2 = DROPBOX_TEST_PATH + "/bin.file" DROPBOX_TEST_FILE_TO_UPLOAD_TEXT = DROPBOX_TEST_PATH + "/text.txt" DROPBOX_TEST_FILE_TO_UPLOAD_BIN = DROPBOX_TEST_PATH + "/file.bin" DROPBOX_TEST_FILE_TO_UPLOAD_LARGE = DROPBOX_TEST_PATH + "/file.blob" DROPBOX_TEST_NON_EXISTING_FILE = DROPBOX_TEST_SIMPLE_DIR + "ajdlkajfal" @pytest.mark.dropbox class 
TestClientDropbox(unittest.TestCase): def setUp(self): self.luigiconn = DropboxClient(DROPBOX_APP_TOKEN) self.dropbox_api = dropbox.dropbox_client.Dropbox(DROPBOX_APP_TOKEN) self.dropbox_api.files_upload(b"hello", DROPBOX_TEST_SIMPLE_FILE) self.dropbox_api.files_upload(b"hello2", DROPBOX_TEST_FILE_IN_DIR) self.dropbox_api.files_upload(b"hello3", DROPBOX_TEST_FILE_TO_MOVE_ORIG) self.dropbox_api.files_upload(b"hello4", DROPBOX_TEST_FILE_TO_COPY_ORIG) self.dropbox_api.files_upload(b"hello3.1", DROPBOX_TEST_FILE_TO_DELETE_1) self.dropbox_api.files_upload(b"hello3.2", DROPBOX_TEST_FILE_TO_DELETE_2) def tearDown(self): self.dropbox_api.files_delete_v2(DROPBOX_TEST_PATH) self.dropbox_api._session.close() def test_exists(self): self.assertTrue(self.luigiconn.exists("/")) self.assertTrue(self.luigiconn.exists(DROPBOX_TEST_PATH)) self.assertTrue(self.luigiconn.exists(DROPBOX_TEST_SIMPLE_DIR)) self.assertTrue(self.luigiconn.exists(DROPBOX_TEST_SIMPLE_DIR + "/")) self.assertTrue(self.luigiconn.exists(DROPBOX_TEST_SIMPLE_FILE)) self.assertFalse(self.luigiconn.exists(DROPBOX_TEST_SIMPLE_FILE + "/")) self.assertFalse(self.luigiconn.exists(DROPBOX_TEST_NON_EXISTING_FILE)) def test_listdir_simple(self): list_of_dirs = self.luigiconn.listdir(DROPBOX_TEST_PATH) self.assertTrue("/" not in list_of_dirs) self.assertTrue(DROPBOX_TEST_PATH in list_of_dirs) self.assertTrue(DROPBOX_TEST_SIMPLE_FILE in list_of_dirs) # we verify recursivity def test_listdir_simple_with_one_slash(self): list_of_dirs = self.luigiconn.listdir(DROPBOX_TEST_PATH + "/") self.assertTrue("/" not in list_of_dirs) self.assertTrue(DROPBOX_TEST_PATH in list_of_dirs) self.assertTrue(DROPBOX_TEST_SIMPLE_FILE in list_of_dirs) # we verify recursivity def test_listdir_multiple(self): list_of_dirs = self.luigiconn.listdir(DROPBOX_TEST_PATH, limit=2) self.assertTrue("/" not in list_of_dirs) self.assertTrue(DROPBOX_TEST_PATH in list_of_dirs) self.assertTrue(DROPBOX_TEST_SIMPLE_FILE in list_of_dirs) # we verify recursivity def 
test_listdir_nonexisting(self): with self.assertRaises(dropbox.exceptions.ApiError): self.luigiconn.listdir(DROPBOX_TEST_NON_EXISTING_FILE) def test_remove(self): # We remove File_to_delete_1. We make sure it is the only file that gets deleted self.assertTrue(self.luigiconn.exists(DROPBOX_TEST_FILE_TO_DELETE_1)) self.assertTrue(self.luigiconn.exists(DROPBOX_TEST_FILE_TO_DELETE_2)) self.assertTrue(self.luigiconn.remove(DROPBOX_TEST_FILE_TO_DELETE_1)) self.assertFalse(self.luigiconn.exists(DROPBOX_TEST_FILE_TO_DELETE_1)) self.assertTrue(self.luigiconn.exists(DROPBOX_TEST_FILE_TO_DELETE_2)) # We remove a directory, we make sure that the files that were in the directory are also deleted self.luigiconn.remove(DROPBOX_TEST_DIR_TO_DELETE) self.assertFalse(self.luigiconn.exists(DROPBOX_TEST_FILE_TO_DELETE_2)) # We make sure that we return False when we fail to remove a non-existing path self.assertFalse(self.luigiconn.remove(DROPBOX_TEST_NON_EXISTING_FILE)) self.assertFalse(self.luigiconn.remove(DROPBOX_TEST_NON_EXISTING_FILE + "/")) def test_mkdir_new_dir(self): self.assertFalse(self.luigiconn.exists(DROPBOX_TEST_DIR_TO_CREATE)) self.assertFalse(self.luigiconn.exists(DROPBOX_TEST_OUTER_DIR_TO_CREATE)) self.luigiconn.mkdir(DROPBOX_TEST_DIR_TO_CREATE) self.assertTrue(self.luigiconn.isdir(DROPBOX_TEST_OUTER_DIR_TO_CREATE)) self.assertTrue(self.luigiconn.isdir(DROPBOX_TEST_DIR_TO_CREATE)) self.assertTrue(self.luigiconn.isdir(DROPBOX_TEST_DIR_TO_CREATE)) def aux_lifecycle_of_directory(self, path): # Initially, the directory does not exists self.assertFalse(self.luigiconn.exists(path)) self.assertFalse(self.luigiconn.isdir(path)) # Now we create the directory and verify that it exists self.luigiconn.mkdir(path) self.assertTrue(self.luigiconn.exists(path)) self.assertTrue(self.luigiconn.isdir(path)) # Now we remote the directory and verify that it no longer exists self.luigiconn.remove(path) self.assertFalse(self.luigiconn.exists(path)) 
self.assertFalse(self.luigiconn.isdir(path)) def test_lifecycle_of_dirpath(self): self.aux_lifecycle_of_directory(DROPBOX_TEST_SIMPLE_DIR_TO_CREATE_AND_DELETE) def test_lifecycle_of_dirpath_with_trailing_slash(self): self.aux_lifecycle_of_directory(DROPBOX_TEST_SIMPLE_DIR_TO_CREATE_AND_DELETE + "/") def test_lifecycle_of_dirpath_with_several_trailing_mixed(self): self.luigiconn.mkdir(DROPBOX_TEST_SIMPLE_DIR_TO_CREATE_AND_DELETE + "/") self.assertTrue(self.luigiconn.exists(DROPBOX_TEST_SIMPLE_DIR_TO_CREATE_AND_DELETE)) self.luigiconn.remove(DROPBOX_TEST_SIMPLE_DIR_TO_CREATE_AND_DELETE) self.assertFalse(self.luigiconn.exists(DROPBOX_TEST_SIMPLE_DIR_TO_CREATE_AND_DELETE + "/")) def test_lifecycle_of_dirpath_with_several_trailing_mixed_2(self): self.luigiconn.mkdir(DROPBOX_TEST_SIMPLE_DIR_TO_CREATE_AND_DELETE) self.assertTrue(self.luigiconn.exists(DROPBOX_TEST_SIMPLE_DIR_TO_CREATE_AND_DELETE + "/")) self.luigiconn.remove(DROPBOX_TEST_SIMPLE_DIR_TO_CREATE_AND_DELETE + "/") self.assertFalse(self.luigiconn.exists(DROPBOX_TEST_SIMPLE_DIR_TO_CREATE_AND_DELETE)) def test_mkdir_new_dir_two_slashes(self): with self.assertRaises(dropbox.dropbox_client.ApiError): self.luigiconn.mkdir(DROPBOX_TEST_SIMPLE_DIR_TO_CREATE_AND_DELETE + "//") def test_mkdir_recreate_dir(self): try: self.luigiconn.mkdir(DROPBOX_TEST_SIMPLE_DIR) except Exception as ex: self.fail("mkdir with default options raises Exception:" + str(ex)) try: self.luigiconn.mkdir(DROPBOX_TEST_SIMPLE_DIR, raise_if_exists=False) except Exception as ex: self.fail("mkdir with 'raise_if_exists=False' raises Exception:" + str(ex)) with self.assertRaises(luigi.target.FileAlreadyExists): self.luigiconn.mkdir(DROPBOX_TEST_SIMPLE_DIR, raise_if_exists=True) def test_mkdir_recreate_slashed_dir(self): try: self.luigiconn.mkdir(DROPBOX_TEST_SIMPLE_DIR + "/") except Exception as ex: self.fail("mkdir with default options raises Exception:" + str(ex)) try: self.luigiconn.mkdir(DROPBOX_TEST_SIMPLE_DIR + "/", raise_if_exists=False) except 
Exception as ex: self.fail("mkdir with 'raise_if_exists=False' raises Exception:" + str(ex)) with self.assertRaises(luigi.target.FileAlreadyExists): self.luigiconn.mkdir(DROPBOX_TEST_SIMPLE_DIR + "/", raise_if_exists=True) def test_mkdir_recreate_file(self): with self.assertRaises(luigi.target.NotADirectory): self.luigiconn.mkdir(DROPBOX_TEST_SIMPLE_FILE) with self.assertRaises(luigi.target.NotADirectory): self.luigiconn.mkdir(DROPBOX_TEST_SIMPLE_FILE, raise_if_exists=True) with self.assertRaises(luigi.target.NotADirectory): self.luigiconn.mkdir(DROPBOX_TEST_SIMPLE_FILE, raise_if_exists=False) def test_isdir(self): self.assertTrue(self.luigiconn.isdir("/")) self.assertTrue(self.luigiconn.isdir(DROPBOX_TEST_PATH)) self.assertTrue(self.luigiconn.isdir(DROPBOX_TEST_SIMPLE_DIR)) self.assertTrue(self.luigiconn.isdir(DROPBOX_TEST_SIMPLE_DIR + "/")) self.assertFalse(self.luigiconn.isdir(DROPBOX_TEST_SIMPLE_FILE)) self.assertFalse(self.luigiconn.isdir(DROPBOX_TEST_NON_EXISTING_FILE)) self.assertFalse(self.luigiconn.isdir(DROPBOX_TEST_NON_EXISTING_FILE + "/")) def test_move(self): md, res = self.dropbox_api.files_download(DROPBOX_TEST_FILE_TO_MOVE_ORIG) initial_contents = res.content self.luigiconn.move(DROPBOX_TEST_FILE_TO_MOVE_ORIG, DROPBOX_TEST_FILE_TO_MOVE_DEST) md, res = self.dropbox_api.files_download(DROPBOX_TEST_FILE_TO_MOVE_DEST) after_moving_contents = res.content self.assertEqual(initial_contents, after_moving_contents) self.assertFalse(self.luigiconn.exists(DROPBOX_TEST_FILE_TO_MOVE_ORIG)) self.assertTrue(self.luigiconn.exists(DROPBOX_TEST_FILE_TO_MOVE_DEST)) def test_copy(self): md, res = self.dropbox_api.files_download(DROPBOX_TEST_FILE_TO_COPY_ORIG) initial_contents = res.content self.luigiconn.copy(DROPBOX_TEST_FILE_TO_COPY_ORIG, DROPBOX_TEST_FILE_TO_COPY_DEST) md, res = self.dropbox_api.files_download(DROPBOX_TEST_FILE_TO_COPY_DEST) after_copyng_contents = res.content self.assertEqual(initial_contents, after_copyng_contents) 
self.assertTrue(self.luigiconn.exists(DROPBOX_TEST_FILE_TO_COPY_ORIG)) self.assertTrue(self.luigiconn.exists(DROPBOX_TEST_FILE_TO_COPY_DEST)) @pytest.mark.dropbox class TestDropboxTarget(unittest.TestCase): def setUp(self): self.luigiconn = DropboxClient(DROPBOX_APP_TOKEN) self.dropbox_api = dropbox.dropbox_client.Dropbox(DROPBOX_APP_TOKEN) self.initial_contents = b"\x00hello\xff\x00-\xe2\x82\x28" # Binary invalid-utf8 sequence self.dropbox_api.files_upload(self.initial_contents, DROPBOX_TEST_SIMPLE_FILE) def tearDown(self): self.dropbox_api.files_delete_v2(DROPBOX_TEST_PATH) self.dropbox_api._session.close() def test_download_from_dropboxtarget_to_local(self): class Download(luigi.ExternalTask): dbx_path = luigi.Parameter() def output(self): return luigi.contrib.dropbox.DropboxTarget(self.dbx_path, DROPBOX_APP_TOKEN, format=NopFormat()) class DbxToLocalTask(luigi.Task): local_path = luigi.Parameter() dbx_path = luigi.Parameter() def requires(self): return Download(dbx_path=self.dbx_path) def output(self): return luigi.LocalTarget(path=self.local_path, format=NopFormat()) def run(self): with self.input().open("r") as dbxfile, self.output().open("w") as localfile: remote_contents = dbxfile.read() localfile.write(remote_contents * 3) tmp_file = tempfile.mkdtemp() + os.sep + "tmp.file" luigi.build([DbxToLocalTask(dbx_path=DROPBOX_TEST_SIMPLE_FILE, local_path=tmp_file)], local_scheduler=True) expected_contents = self.initial_contents * 3 with open(tmp_file, "rb") as f: actual_contents = f.read() self.assertEqual(expected_contents, actual_contents) def test_write_small_text_file_to_dropbox(self): small_input_text = "The greatest glory in living lies not in never falling\nbut in rising every time we fall." 
class WriteToDrobopxTest(luigi.Task): def output(self): return luigi.contrib.dropbox.DropboxTarget(DROPBOX_TEST_FILE_TO_UPLOAD_TEXT, DROPBOX_APP_TOKEN) def run(self): with self.output().open("w") as dbxfile: dbxfile.write(small_input_text) luigi.build([WriteToDrobopxTest()], local_scheduler=True) actual_content = self.dropbox_api.files_download(DROPBOX_TEST_FILE_TO_UPLOAD_TEXT)[1].content self.assertEqual(actual_content.decode(), small_input_text) def aux_write_binary_file_to_dropbox(self, multiplier): large_contents = b"X\n\xe2\x28\xa1" * multiplier output_file = DROPBOX_TEST_FILE_TO_UPLOAD_LARGE + str(multiplier) class WriteToDrobopxTest(luigi.Task): def output(self): return luigi.contrib.dropbox.DropboxTarget(output_file, DROPBOX_APP_TOKEN, format=luigi.format.Nop) def run(self): with self.output().open("w") as dbxfile: dbxfile.write(large_contents) luigi.build([WriteToDrobopxTest()], local_scheduler=True) actual_content = self.dropbox_api.files_download(output_file)[1].content self.assertEqual(actual_content, large_contents) def test_write_small_binary_file_to_dropbox(self): self.aux_write_binary_file_to_dropbox(1024) def test_write_medium_binary_file_to_dropbox(self): self.aux_write_binary_file_to_dropbox(1024 * 1024) def test_write_large_binary_file_to_dropbox(self): self.aux_write_binary_file_to_dropbox(3 * 1024 * 1024) def test_write_using_nondefault_format(self): contents = b"X\n\xe2\x28\xa1" class WriteToDrobopxTest(luigi.Task): def output(self): return luigi.contrib.dropbox.DropboxTarget(DROPBOX_TEST_FILE_TO_UPLOAD_BZIP2, DROPBOX_APP_TOKEN, format=luigi.format.Bzip2) def run(self): with self.output().open("w") as bzip2_dbxfile: bzip2_dbxfile.write(contents) luigi.build([WriteToDrobopxTest()], local_scheduler=True) remote_content = self.dropbox_api.files_download(DROPBOX_TEST_FILE_TO_UPLOAD_BZIP2)[1].content self.assertEqual(contents, bz2.decompress(remote_content)) def test_write_using_a_temporary_path(self): contents = b"X\n\xe2\x28\xa1" class 
WriteToDrobopxTest(luigi.Task): def output(self): return luigi.contrib.dropbox.DropboxTarget(DROPBOX_TEST_FILE_TO_UPLOAD_BIN, DROPBOX_APP_TOKEN) def run(self): with self.output().temporary_path() as tmp_path: open(tmp_path, "wb").write(contents) luigi.build([WriteToDrobopxTest()], local_scheduler=True) actual_content = self.dropbox_api.files_download(DROPBOX_TEST_FILE_TO_UPLOAD_BIN)[1].content self.assertEqual(actual_content, contents) ================================================ FILE: test/contrib/ecs_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015 Outlier Bio, LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Integration test for the Luigi wrapper of EC2 Container Service (ECSTask) Requires: - boto3 package - Amazon AWS credentials discoverable by boto3 (e.g., by using ``aws configure`` from awscli_) - A running ECS cluster (see `ECS Get Started`_) Written and maintained by Jake Feala (@jfeala) for Outlier Bio (@outlierbio) .. _awscli: https://aws.amazon.com/cli .. _`ECS Get Started`: http://docs.aws.amazon.com/AmazonECS/latest/developerguide/ECS_GetStarted.html """ import unittest import pytest from moto import mock_ecs import luigi from luigi.contrib.ecs import ECSTask, _get_task_statuses try: import boto3 except ImportError: raise unittest.SkipTest("boto3 is not installed. 
ECSTasks require boto3") TEST_TASK_DEF = { "family": "hello-world", "volumes": [], "containerDefinitions": [ {"memory": 1, "essential": True, "name": "hello-world", "image": "ubuntu", "command": ["/bin/echo", "hello world"]}, {"memory": 1, "essential": True, "name": "hello-world-2", "image": "ubuntu", "command": ["/bin/echo", "hello world #2!"]}, ], } class ECSTaskNoOutput(ECSTask): def complete(self): if self.ecs_task_ids: return all([status == "STOPPED" for status in _get_task_statuses(self.ecs_task_ids)]) return False class ECSTaskOverrideCommand(ECSTaskNoOutput): @property def command(self): return [{"name": "hello-world", "command": ["/bin/sleep", "10"]}] class ECSTaskCustomRunTaskKwargs(ECSTaskNoOutput): @property def run_task_kwargs(self): return {"overrides": {"ephemeralStorage": {"sizeInGiB": 30}}} class ECSTaskCustomRunTaskKwargsWithCollidingCommand(ECSTaskNoOutput): @property def command(self): return [ {"name": "hello-world", "command": ["/bin/sleep", "10"]}, {"name": "hello-world-2", "command": ["/bin/sleep", "10"]}, ] @property def run_task_kwargs(self): return { "launchType": "FARGATE", "platformVersion": "1.4.0", "networkConfiguration": { "awsvpcConfiguration": { "subnets": ["subnet-01234567890abcdef", "subnet-abcdef01234567890"], "securityGroups": [ "sg-abcdef01234567890", ], "assignPublicIp": "ENABLED", } }, "overrides": {"containerOverrides": [{"name": "hello-world-2", "command": ["command-to-be-overwritten"]}], "ephemeralStorage": {"sizeInGiB": 30}}, } class ECSTaskCustomRunTaskKwargsWithMergedCommands(ECSTaskNoOutput): @property def command(self): return [{"name": "hello-world", "command": ["/bin/sleep", "10"]}] @property def run_task_kwargs(self): return { "launchType": "FARGATE", "platformVersion": "1.4.0", "networkConfiguration": { "awsvpcConfiguration": { "subnets": ["subnet-01234567890abcdef", "subnet-abcdef01234567890"], "securityGroups": [ "sg-abcdef01234567890", ], "assignPublicIp": "ENABLED", } }, "overrides": {"containerOverrides": 
[{"name": "hello-world-2", "command": ["/bin/sleep", "10"]}], "ephemeralStorage": {"sizeInGiB": 30}}, } @pytest.mark.aws class TestECSTask(unittest.TestCase): @mock_ecs def setUp(self): # Register the test task definition response = boto3.client("ecs").register_task_definition(**TEST_TASK_DEF) self.arn = response["taskDefinition"]["taskDefinitionArn"] @mock_ecs def test_unregistered_task(self): t = ECSTaskNoOutput(task_def=TEST_TASK_DEF) luigi.build([t], local_scheduler=True) @mock_ecs def test_registered_task(self): t = ECSTaskNoOutput(task_def_arn=self.arn) luigi.build([t], local_scheduler=True) @mock_ecs def test_override_command(self): t = ECSTaskOverrideCommand(task_def_arn=self.arn) luigi.build([t], local_scheduler=True) @mock_ecs def test_custom_run_task_kwargs(self): t = ECSTaskCustomRunTaskKwargs(task_def_arn=self.arn) self.assertEqual(t.combined_overrides, {"ephemeralStorage": {"sizeInGiB": 30}}) luigi.build([t], local_scheduler=True) @mock_ecs def test_custom_run_task_kwargs_with_colliding_command(self): t = ECSTaskCustomRunTaskKwargsWithCollidingCommand(task_def_arn=self.arn) combined_overrides = t.combined_overrides self.assertEqual( sorted(combined_overrides["containerOverrides"], key=lambda x: x["name"]), sorted( [ {"name": "hello-world", "command": ["/bin/sleep", "10"]}, {"name": "hello-world-2", "command": ["/bin/sleep", "10"]}, ], key=lambda x: x["name"], ), ) self.assertEqual(combined_overrides["ephemeralStorage"], {"sizeInGiB": 30}) luigi.build([t], local_scheduler=True) @mock_ecs def test_custom_run_task_kwargs_with_merged_commands(self): t = ECSTaskCustomRunTaskKwargsWithMergedCommands(task_def_arn=self.arn) combined_overrides = t.combined_overrides self.assertEqual( sorted(combined_overrides["containerOverrides"], key=lambda x: x["name"]), sorted( [ {"name": "hello-world", "command": ["/bin/sleep", "10"]}, {"name": "hello-world-2", "command": ["/bin/sleep", "10"]}, ], key=lambda x: x["name"], ), ) 
self.assertEqual(combined_overrides["ephemeralStorage"], {"sizeInGiB": 30}) luigi.build([t], local_scheduler=True) ================================================ FILE: test/contrib/esindex_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Tests for Elasticsearch index (esindex) target and indexing. An Elasticsearch server must be running for these tests. To use a non-standard host and port, use `ESINDEX_TEST_HOST`, `ESINDEX_TEST_PORT` environment variables to override defaults. To test HTTP basic authentication `ESINDEX_TEST_HTTP_AUTH`. 
Example running tests against port 9201 with basic auth: $ ESINDEX_TEST_PORT=9201 ESINDEX_TEST_HTTP_AUTH='admin:admin' nosetests test/_esindex_test.py """ # pylint: disable=C0103,E1101,F0401 import collections import datetime import os import elasticsearch import pytest from elasticsearch.connection import Urllib3HttpConnection from helpers import unittest import luigi from luigi.contrib.esindex import CopyToIndex, ElasticsearchTarget HOST = os.getenv("ESINDEX_TEST_HOST", "localhost") PORT = os.getenv("ESINDEX_TEST_PORT", 9200) HTTP_AUTH = os.getenv("ESINDEX_TEST_HTTP_AUTH", None) INDEX = "esindex_luigi_test" DOC_TYPE = "esindex_test_type" MARKER_INDEX = "esindex_luigi_test_index_updates" MARKER_DOC_TYPE = "esindex_test_entry" def _create_test_index(): """Create content index, if if does not exists.""" es = elasticsearch.Elasticsearch(connection_class=Urllib3HttpConnection, host=HOST, port=PORT, http_auth=HTTP_AUTH) if not es.indices.exists(INDEX): es.indices.create(INDEX) try: _create_test_index() except Exception: raise unittest.SkipTest("Unable to connect to ElasticSearch") @pytest.mark.aws class ElasticsearchTargetTest(unittest.TestCase): """Test touch and exists.""" def test_touch_and_exists(self): """Basic test.""" target = ElasticsearchTarget(HOST, PORT, INDEX, DOC_TYPE, "update_id", http_auth=HTTP_AUTH) target.marker_index = MARKER_INDEX target.marker_doc_type = MARKER_DOC_TYPE delete() self.assertFalse(target.exists(), "Target should not exist before touching it") target.touch() self.assertTrue(target.exists(), "Target should exist after touching it") delete() def delete(): """Delete marker_index, if it exists.""" es = elasticsearch.Elasticsearch(connection_class=Urllib3HttpConnection, host=HOST, port=PORT, http_auth=HTTP_AUTH) if es.indices.exists(MARKER_INDEX): es.indices.delete(MARKER_INDEX) es.indices.refresh() class CopyToTestIndex(CopyToIndex): """Override the default `marker_index` table with a test name.""" host = HOST port = PORT http_auth = 
HTTP_AUTH index = INDEX doc_type = DOC_TYPE marker_index_hist_size = 0 def output(self): """Use a test target with an own marker_index.""" target = ElasticsearchTarget( host=self.host, port=self.port, http_auth=self.http_auth, index=self.index, doc_type=self.doc_type, update_id=self.update_id(), marker_index_hist_size=self.marker_index_hist_size, ) target.marker_index = MARKER_INDEX target.marker_doc_type = MARKER_DOC_TYPE return target class IndexingTask1(CopyToTestIndex): """Test the redundant version, where `_index` and `_type` are given in the `docs` as well. A more DRY example is `IndexingTask2`.""" def docs(self): """Return a list with a single doc.""" return [{"_id": 123, "_index": self.index, "_type": self.doc_type, "name": "sample", "date": "today"}] class IndexingTask2(CopyToTestIndex): """Just another task.""" def docs(self): """Return a list with a single doc.""" return [{"_id": 234, "_index": self.index, "_type": self.doc_type, "name": "another", "date": "today"}] class IndexingTask3(CopyToTestIndex): """This task will request an empty index to start with.""" purge_existing_index = True def docs(self): """Return a list with a single doc.""" return [{"_id": 234, "_index": self.index, "_type": self.doc_type, "name": "yet another", "date": "today"}] def _cleanup(): """Delete both the test marker index and the content index.""" es = elasticsearch.Elasticsearch(connection_class=Urllib3HttpConnection, host=HOST, port=PORT, http_auth=HTTP_AUTH) if es.indices.exists(MARKER_INDEX): es.indices.delete(MARKER_INDEX) if es.indices.exists(INDEX): es.indices.delete(INDEX) @pytest.mark.aws class CopyToIndexTest(unittest.TestCase): """Test indexing tasks.""" @classmethod def setUpClass(cls): cls.es = elasticsearch.Elasticsearch(connection_class=Urllib3HttpConnection, host=HOST, port=PORT, http_auth=HTTP_AUTH) def setUp(self): """Cleanup before each test.""" _cleanup() def tearDown(self): """Remove residues after each test.""" _cleanup() def test_copy_to_index(self): 
"""Test a single document upload.""" task = IndexingTask1() self.assertFalse(self.es.indices.exists(task.index)) self.assertFalse(task.complete()) luigi.build([task], local_scheduler=True) self.assertTrue(self.es.indices.exists(task.index)) self.assertTrue(task.complete()) self.assertEqual(1, self.es.count(index=task.index).get("count")) self.assertEqual({"date": "today", "name": "sample"}, self.es.get_source(index=task.index, doc_type=task.doc_type, id=123)) def test_copy_to_index_incrementally(self): """Test two tasks that upload docs into the same index.""" task1 = IndexingTask1() task2 = IndexingTask2() self.assertFalse(self.es.indices.exists(task1.index)) self.assertFalse(self.es.indices.exists(task2.index)) self.assertFalse(task1.complete()) self.assertFalse(task2.complete()) luigi.build([task1, task2], local_scheduler=True) self.assertTrue(self.es.indices.exists(task1.index)) self.assertTrue(self.es.indices.exists(task2.index)) self.assertTrue(task1.complete()) self.assertTrue(task2.complete()) self.assertEqual(2, self.es.count(index=task1.index).get("count")) self.assertEqual(2, self.es.count(index=task2.index).get("count")) self.assertEqual({"date": "today", "name": "sample"}, self.es.get_source(index=task1.index, doc_type=task1.doc_type, id=123)) self.assertEqual({"date": "today", "name": "another"}, self.es.get_source(index=task2.index, doc_type=task2.doc_type, id=234)) def test_copy_to_index_purge_existing(self): """Test purge_existing_index purges index.""" task1 = IndexingTask1() task2 = IndexingTask2() task3 = IndexingTask3() luigi.build([task1, task2], local_scheduler=True) luigi.build([task3], local_scheduler=True) self.assertTrue(self.es.indices.exists(task3.index)) self.assertTrue(task3.complete()) self.assertEqual(1, self.es.count(index=task3.index).get("count")) self.assertEqual({"date": "today", "name": "yet another"}, self.es.get_source(index=task3.index, doc_type=task3.doc_type, id=234)) @pytest.mark.aws class 
MarkerIndexTest(unittest.TestCase): @classmethod def setUpClass(cls): cls.es = elasticsearch.Elasticsearch(connection_class=Urllib3HttpConnection, host=HOST, port=PORT, http_auth=HTTP_AUTH) def setUp(self): """Cleanup before each test.""" _cleanup() def tearDown(self): """Remove residues after each test.""" _cleanup() def test_update_marker(self): def will_raise(): self.es.count(index=MARKER_INDEX, doc_type=MARKER_DOC_TYPE, body={"query": {"match_all": {}}}) self.assertRaises(elasticsearch.NotFoundError, will_raise) task1 = IndexingTask1() luigi.build([task1], local_scheduler=True) result = self.es.count(index=MARKER_INDEX, doc_type=MARKER_DOC_TYPE, body={"query": {"match_all": {}}}) self.assertEqual(1, result.get("count")) result = self.es.search(index=MARKER_INDEX, doc_type=MARKER_DOC_TYPE, body={"query": {"match_all": {}}}) marker_doc = result.get("hits").get("hits")[0].get("_source") self.assertEqual(task1.task_id, marker_doc.get("update_id")) self.assertEqual(INDEX, marker_doc.get("target_index")) self.assertEqual(DOC_TYPE, marker_doc.get("target_doc_type")) self.assertTrue("date" in marker_doc) task2 = IndexingTask2() luigi.build([task2], local_scheduler=True) result = self.es.count(index=MARKER_INDEX, doc_type=MARKER_DOC_TYPE, body={"query": {"match_all": {}}}) self.assertEqual(2, result.get("count")) result = self.es.search(index=MARKER_INDEX, doc_type=MARKER_DOC_TYPE, body={"query": {"match_all": {}}}) hits = result.get("hits").get("hits") Entry = collections.namedtuple("Entry", ["date", "update_id"]) dates_update_id = [] for hit in hits: source = hit.get("_source") update_id = source.get("update_id") date = source.get("date") dates_update_id.append(Entry(date, update_id)) it = iter(sorted(dates_update_id)) first = next(it) second = next(it) self.assertTrue(first.date < second.date) self.assertEqual(first.update_id, task1.task_id) self.assertEqual(second.update_id, task2.task_id) class IndexingTask4(CopyToTestIndex): """Just another task.""" date = 
luigi.DateParameter(default=datetime.date(1970, 1, 1)) marker_index_hist_size = 1 def docs(self): """Return a list with a single doc.""" return [{"_id": 234, "_index": self.index, "_type": self.doc_type, "name": "another", "date": "today"}] @pytest.mark.aws class IndexHistSizeTest(unittest.TestCase): @classmethod def setUpClass(cls): cls.es = elasticsearch.Elasticsearch(connection_class=Urllib3HttpConnection, host=HOST, port=PORT, http_auth=HTTP_AUTH) def setUp(self): """Cleanup before each test.""" _cleanup() def tearDown(self): """Remove residues after each test.""" _cleanup() def test_limited_history(self): task4_1 = IndexingTask4(date=datetime.date(2000, 1, 1)) luigi.build([task4_1], local_scheduler=True) task4_2 = IndexingTask4(date=datetime.date(2001, 1, 1)) luigi.build([task4_2], local_scheduler=True) task4_3 = IndexingTask4(date=datetime.date(2002, 1, 1)) luigi.build([task4_3], local_scheduler=True) result = self.es.count(index=MARKER_INDEX, doc_type=MARKER_DOC_TYPE, body={"query": {"match_all": {}}}) self.assertEqual(1, result.get("count")) marker_index_document_id = task4_3.output().marker_index_document_id() result = self.es.get(id=marker_index_document_id, index=MARKER_INDEX, doc_type=MARKER_DOC_TYPE) self.assertEqual(task4_3.task_id, result.get("_source").get("update_id")) ================================================ FILE: test/contrib/external_daily_snapshot_test.py ================================================ # Copyright (c) 2013 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of # the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the # License for the specific language governing permissions and limitations under # the License. import datetime import unittest import luigi from luigi.contrib.external_daily_snapshot import ExternalDailySnapshot from luigi.mock import MockTarget class DataDump(ExternalDailySnapshot): param = luigi.Parameter() a = luigi.Parameter(default="zebra") aa = luigi.Parameter(default="Congo") def output(self): return MockTarget("data-%s-%s-%s-%s" % (self.param, self.a, self.aa, self.date)) class ExternalDailySnapshotTest(unittest.TestCase): def test_latest(self): MockTarget("data-xyz-zebra-Congo-2012-01-01").open("w").close() d = DataDump.latest(date=datetime.date(2012, 1, 10), param="xyz") self.assertEqual(d.date, datetime.date(2012, 1, 1)) def test_latest_not_exists(self): MockTarget("data-abc-zebra-Congo-2012-01-01").open("w").close() d = DataDump.latest(date=datetime.date(2012, 1, 11), param="abc", lookback=5) self.assertEqual(d.date, datetime.date(2012, 1, 7)) def test_deterministic(self): MockTarget("data-pqr-zebra-Congo-2012-01-01").open("w").close() d = DataDump.latest(date=datetime.date(2012, 1, 10), param="pqr", a="zebra", aa="Congo") self.assertEqual(d.date, datetime.date(2012, 1, 1)) MockTarget("data-pqr-zebra-Congo-2012-01-05").open("w").close() d = DataDump.latest(date=datetime.date(2012, 1, 10), param="pqr", aa="Congo", a="zebra") self.assertEqual(d.date, datetime.date(2012, 1, 1)) # Should still be the same ================================================ FILE: test/contrib/external_program_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2016 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import os import shutil import subprocess import tempfile from functools import partial from io import BytesIO from multiprocessing import Value from subprocess import Popen import mock import pytest from helpers import unittest from mock import call, patch import luigi import luigi.contrib.hdfs from luigi.contrib.external_program import ExternalProgramRunError, ExternalProgramTask, ExternalPythonProgramTask def poll_generator(): yield None yield 1 def setup_run_process(proc): poll_gen = poll_generator() proc.return_value.poll = lambda: next(poll_gen) proc.return_value.returncode = 0 proc.return_value.stdout = BytesIO() proc.return_value.stderr = BytesIO() class TestExternalProgramTask(ExternalProgramTask): def program_args(self): return ["app_path", "arg1", "arg2"] def output(self): return luigi.LocalTarget("output") class TestLogStderrOnFailureOnlyTask(TestExternalProgramTask): always_log_stderr = False class TestTouchTask(ExternalProgramTask): file_path = luigi.Parameter() def program_args(self): return ["touch", self.output().path] def output(self): return luigi.LocalTarget(self.file_path) class TestEchoTask(ExternalProgramTask): MESSAGE = "Hello, world!" 
def program_args(self): return ["echo", self.MESSAGE] @pytest.mark.contrib class ExternalProgramTaskTest(unittest.TestCase): @patch("luigi.contrib.external_program.subprocess.Popen") def test_run(self, proc): setup_run_process(proc) job = TestExternalProgramTask() job.run() self.assertEqual(proc.call_args[0][0], ["app_path", "arg1", "arg2"]) @patch("luigi.contrib.external_program.logger") @patch("luigi.contrib.external_program.tempfile.TemporaryFile") @patch("luigi.contrib.external_program.subprocess.Popen") def test_handle_failed_job(self, proc, file, logger): proc.return_value.returncode = 1 file.return_value = BytesIO(b"stderr") try: job = TestExternalProgramTask() job.run() except ExternalProgramRunError as e: self.assertEqual(e.err, "stderr") self.assertIn("STDERR: stderr", str(e)) self.assertIn(call.info("Program stderr:\nstderr"), logger.mock_calls) else: self.fail("Should have thrown ExternalProgramRunError") @patch("luigi.contrib.external_program.logger") @patch("luigi.contrib.external_program.tempfile.TemporaryFile") @patch("luigi.contrib.external_program.subprocess.Popen") def test_always_log_stderr_on_failure(self, proc, file, logger): proc.return_value.returncode = 1 file.return_value = BytesIO(b"stderr") with self.assertRaises(ExternalProgramRunError): job = TestLogStderrOnFailureOnlyTask() job.run() self.assertIn(call.info("Program stderr:\nstderr"), logger.mock_calls) @patch("luigi.contrib.external_program.logger") @patch("luigi.contrib.external_program.tempfile.TemporaryFile") @patch("luigi.contrib.external_program.subprocess.Popen") def test_log_stderr_on_success_by_default(self, proc, file, logger): proc.return_value.returncode = 0 file.return_value = BytesIO(b"stderr") job = TestExternalProgramTask() job.run() self.assertIn(call.info("Program stderr:\nstderr"), logger.mock_calls) def test_capture_output_set_to_false_writes_output_to_stdout(self): out = tempfile.TemporaryFile() def Popen_wrap(args, **kwargs): kwargs.pop("stdout", None) return 
Popen(args, stdout=out, **kwargs) with mock.patch("luigi.contrib.external_program.subprocess.Popen", wraps=Popen_wrap): task = TestEchoTask(capture_output=False) task.run() stdout = task._clean_output_file(out).strip() self.assertEqual(stdout, task.MESSAGE) @patch("luigi.contrib.external_program.logger") @patch("luigi.contrib.external_program.tempfile.TemporaryFile") @patch("luigi.contrib.external_program.subprocess.Popen") def test_dont_log_stderr_on_success_if_disabled(self, proc, file, logger): proc.return_value.returncode = 0 file.return_value = BytesIO(b"stderr") job = TestLogStderrOnFailureOnlyTask() job.run() self.assertNotIn(call.info("Program stderr:\nstderr"), logger.mock_calls) @patch("luigi.contrib.external_program.subprocess.Popen") def test_program_args_must_be_implemented(self, proc): with self.assertRaises(NotImplementedError): job = ExternalProgramTask() job.run() @patch("luigi.contrib.external_program.subprocess.Popen") def test_app_interruption(self, proc): def interrupt(): raise KeyboardInterrupt() proc.return_value.wait = interrupt try: job = TestExternalProgramTask() job.run() except KeyboardInterrupt: pass proc.return_value.kill.check_called() def test_non_mocked_task_run(self): # create a tempdir first, to ensure an empty playground for # TestTouchTask to create its file in tempdir = tempfile.mkdtemp() tempfile_path = os.path.join(tempdir, "testfile") try: job = TestTouchTask(file_path=tempfile_path) job.run() self.assertTrue(luigi.LocalTarget(tempfile_path).exists()) finally: # clean up temp files even if assertion fails shutil.rmtree(tempdir) def test_tracking_url_pattern_works_with_capture_output_disabled(self): test_val = Value("i", 0) def fake_set_tracking_url(val, url): if url == "TEXT": val.value += 1 task = TestEchoTask(capture_output=False, stream_for_searching_tracking_url="stdout", tracking_url_pattern=r"SOME (.*)") task.MESSAGE = "SOME TEXT" with mock.patch.object(task, "set_tracking_url", new=partial(fake_set_tracking_url, 
test_val)): task.run() self.assertEqual(test_val.value, 1) def test_tracking_url_pattern_works_with_capture_output_enabled(self): test_val = Value("i", 0) def fake_set_tracking_url(val, url): if url == "THING": val.value += 1 task = TestEchoTask(capture_output=True, stream_for_searching_tracking_url="stdout", tracking_url_pattern=r"ANY(.*)") task.MESSAGE = "ANYTHING" with mock.patch.object(task, "set_tracking_url", new=partial(fake_set_tracking_url, test_val)): task.run() self.assertEqual(test_val.value, 1) def test_tracking_url_pattern_works_with_stderr(self): test_val = Value("i", 0) def fake_set_tracking_url(val, url): if url == "THING_ELSE": val.value += 1 def Popen_wrap(args, **kwargs): return Popen('>&2 echo "ANYTHING_ELSE"', shell=True, **kwargs) task = TestEchoTask(capture_output=True, stream_for_searching_tracking_url="stderr", tracking_url_pattern=r"ANY(.*)") with mock.patch("luigi.contrib.external_program.subprocess.Popen", wraps=Popen_wrap): with mock.patch.object(task, "set_tracking_url", new=partial(fake_set_tracking_url, test_val)): task.run() self.assertEqual(test_val.value, 1) def test_no_url_searching_is_performed_if_pattern_is_not_set(self): def Popen_wrap(args, **kwargs): # stdout should not be replaced with pipe if tracking_url_pattern is not set self.assertNotEqual(kwargs["stdout"], subprocess.PIPE) return Popen(args, **kwargs) task = TestEchoTask(capture_output=True, stream_for_searching_tracking_url="stdout") with mock.patch("luigi.contrib.external_program.subprocess.Popen", wraps=Popen_wrap): task.run() def test_tracking_url_context_works_without_capture_output(self): test_val = Value("i", 0) def fake_set_tracking_url(val, url): if url == "world": val.value += 1 task = TestEchoTask(capture_output=False, stream_for_searching_tracking_url="stdout", tracking_url_pattern=r"Hello, (.*)!") test_args = list(map(str, task.program_args())) with mock.patch.object(task, "set_tracking_url", new=partial(fake_set_tracking_url, test_val)): with 
task._proc_with_tracking_url_context(proc_args=test_args, proc_kwargs={}) as proc: proc.wait() self.assertEqual(test_val.value, 1) def test_tracking_url_context_works_correctly_when_logs_output_pattern_to_url_is_not_default(self): class _Task(TestEchoTask): def build_tracking_url(self, logs_output): return "The {} is mine".format(logs_output) test_val = Value("i", 0) def fake_set_tracking_url(val, url): if url == "The world is mine": val.value += 1 task = _Task(capture_output=False, stream_for_searching_tracking_url="stdout", tracking_url_pattern=r"Hello, (.*)!") test_args = list(map(str, task.program_args())) with mock.patch.object(task, "set_tracking_url", new=partial(fake_set_tracking_url, test_val)): with task._proc_with_tracking_url_context(proc_args=test_args, proc_kwargs={}) as proc: proc.wait() self.assertEqual(test_val.value, 1) class TestExternalPythonProgramTask(ExternalPythonProgramTask): virtualenv = "/path/to/venv" extra_pythonpath = "/extra/pythonpath" def program_args(self): return ["app_path", "arg1", "arg2"] def output(self): return luigi.LocalTarget("output") @pytest.mark.contrib class ExternalPythonProgramTaskTest(unittest.TestCase): @patch.dict("os.environ", {"OTHERVAR": "otherval"}, clear=True) @patch("luigi.contrib.external_program.subprocess.Popen") def test_original_environment_is_kept_intact(self, proc): setup_run_process(proc) job = TestExternalPythonProgramTask() job.run() proc_env = proc.call_args[1]["env"] self.assertIn("PYTHONPATH", proc_env) self.assertIn("OTHERVAR", proc_env) @patch.dict("os.environ", {"PATH": "/base/path"}, clear=True) @patch("luigi.contrib.external_program.subprocess.Popen") def test_venv_is_set_and_prepended_to_path(self, proc): setup_run_process(proc) job = TestExternalPythonProgramTask() job.run() proc_env = proc.call_args[1]["env"] self.assertIn("PATH", proc_env) self.assertTrue(proc_env["PATH"].startswith("/path/to/venv/bin")) self.assertTrue(proc_env["PATH"].endswith("/base/path")) 
self.assertIn("VIRTUAL_ENV", proc_env) self.assertEqual(proc_env["VIRTUAL_ENV"], "/path/to/venv") @patch.dict("os.environ", {}, clear=True) @patch("luigi.contrib.external_program.subprocess.Popen") def test_pythonpath_is_set_if_empty(self, proc): setup_run_process(proc) job = TestExternalPythonProgramTask() job.run() proc_env = proc.call_args[1]["env"] self.assertIn("PYTHONPATH", proc_env) self.assertTrue(proc_env["PYTHONPATH"].startswith("/extra/pythonpath")) @patch.dict("os.environ", {"PYTHONPATH": "/base/pythonpath"}, clear=True) @patch("luigi.contrib.external_program.subprocess.Popen") def test_pythonpath_is_prepended_if_not_empty(self, proc): setup_run_process(proc) job = TestExternalPythonProgramTask() job.run() proc_env = proc.call_args[1]["env"] self.assertIn("PYTHONPATH", proc_env) self.assertTrue(proc_env["PYTHONPATH"].startswith("/extra/pythonpath")) self.assertTrue(proc_env["PYTHONPATH"].endswith("/base/pythonpath")) ================================================ FILE: test/contrib/gcs_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015 Twitter Inc # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """This is an integration test for the GCS-luigi binding. This test requires credentials that can access GCS & access to a bucket below. Follow the directions in the gcloud tools to set up local credentials. 
""" from helpers import unittest try: import google.auth import googleapiclient.errors except ImportError: raise unittest.SkipTest("Unable to load googleapiclient module") import os import tempfile import unittest from unittest import mock import pytest from target_test import FileSystemTargetTestMixin from luigi.contrib import gcs # In order to run this test, you should set these to your GCS project/bucket. # Unfortunately there's no mock PROJECT_ID = os.environ.get("GCS_TEST_PROJECT_ID", "your_project_id_here") BUCKET_NAME = os.environ.get("GCS_TEST_BUCKET", "your_test_bucket_here") TEST_FOLDER = os.environ.get("TRAVIS_BUILD_ID", "gcs_test_folder") CREDENTIALS, _ = google.auth.default() ATTEMPTED_BUCKET_CREATE = False def bucket_url(suffix): """ Actually it's bucket + test folder name """ return "gs://{}/{}/{}".format(BUCKET_NAME, TEST_FOLDER, suffix) class _GCSBaseTestCase(unittest.TestCase): def setUp(self): self.client = gcs.GCSClient(CREDENTIALS) global ATTEMPTED_BUCKET_CREATE if not ATTEMPTED_BUCKET_CREATE: try: self.client.client.buckets().insert(project=PROJECT_ID, body={"name": BUCKET_NAME}).execute() except googleapiclient.errors.HttpError as ex: if ex.resp.status != 409: # bucket already exists raise ATTEMPTED_BUCKET_CREATE = True self.client.remove(bucket_url(""), recursive=True) self.client.mkdir(bucket_url("")) def tearDown(self): self.client.remove(bucket_url(""), recursive=True) @pytest.mark.gcloud class GCSClientTest(_GCSBaseTestCase): def test_not_exists(self): self.assertFalse(self.client.exists(bucket_url("does_not_exist"))) self.assertFalse(self.client.isdir(bucket_url("does_not_exist"))) def test_exists(self): self.client.put_string("hello", bucket_url("exists_test")) self.assertTrue(self.client.exists(bucket_url("exists_test"))) self.assertFalse(self.client.isdir(bucket_url("exists_test"))) def test_mkdir(self): self.client.mkdir(bucket_url("exists_dir_test")) self.assertTrue(self.client.exists(bucket_url("exists_dir_test"))) 
self.assertTrue(self.client.isdir(bucket_url("exists_dir_test"))) def test_mkdir_by_upload(self): self.client.put_string("hello", bucket_url("test_dir_recursive/yep/file")) self.assertTrue(self.client.exists(bucket_url("test_dir_recursive"))) self.assertTrue(self.client.isdir(bucket_url("test_dir_recursive"))) def test_download(self): self.client.put_string("hello", bucket_url("test_download")) fp = self.client.download(bucket_url("test_download")) self.assertEqual(b"hello", fp.read()) def test_rename(self): self.client.put_string("hello", bucket_url("test_rename_1")) self.client.rename(bucket_url("test_rename_1"), bucket_url("test_rename_2")) self.assertFalse(self.client.exists(bucket_url("test_rename_1"))) self.assertTrue(self.client.exists(bucket_url("test_rename_2"))) def test_rename_recursive(self): self.client.mkdir(bucket_url("test_rename_recursive")) self.client.put_string("hello", bucket_url("test_rename_recursive/1")) self.client.put_string("hello", bucket_url("test_rename_recursive/2")) self.client.rename(bucket_url("test_rename_recursive"), bucket_url("test_rename_recursive_dest")) self.assertFalse(self.client.exists(bucket_url("test_rename_recursive"))) self.assertFalse(self.client.exists(bucket_url("test_rename_recursive/1"))) self.assertTrue(self.client.exists(bucket_url("test_rename_recursive_dest"))) self.assertTrue(self.client.exists(bucket_url("test_rename_recursive_dest/1"))) def test_remove(self): self.client.put_string("hello", bucket_url("test_remove")) self.client.remove(bucket_url("test_remove")) self.assertFalse(self.client.exists(bucket_url("test_remove"))) def test_remove_recursive(self): self.client.mkdir(bucket_url("test_remove_recursive")) self.client.put_string("hello", bucket_url("test_remove_recursive/1")) self.client.put_string("hello", bucket_url("test_remove_recursive/2")) self.client.remove(bucket_url("test_remove_recursive")) self.assertFalse(self.client.exists(bucket_url("test_remove_recursive"))) 
self.assertFalse(self.client.exists(bucket_url("test_remove_recursive/1"))) self.assertFalse(self.client.exists(bucket_url("test_remove_recursive/2"))) def test_listdir(self): self.client.put_string("hello", bucket_url("test_listdir/1")) self.client.put_string("hello", bucket_url("test_listdir/2")) self.assertEqual([bucket_url("test_listdir/1"), bucket_url("test_listdir/2")], list(self.client.listdir(bucket_url("test_listdir/")))) self.assertEqual([bucket_url("test_listdir/1"), bucket_url("test_listdir/2")], list(self.client.listdir(bucket_url("test_listdir")))) def test_put_file(self): with tempfile.NamedTemporaryFile() as fp: lorem = b"Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh euismod tincidunt\n" # Larger file than chunk size, fails with incorrect progress set up big = lorem * 41943 fp.write(big) fp.flush() self.client.put(fp.name, bucket_url("test_put_file")) self.assertTrue(self.client.exists(bucket_url("test_put_file"))) self.assertEqual(big, self.client.download(bucket_url("test_put_file")).read()) def test_put_file_multiproc(self): temporary_fps = [] for _ in range(2): fp = tempfile.NamedTemporaryFile(mode="wb") lorem = b"Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh euismod tincidunt\n" # Larger file than chunk size, fails with incorrect progress set up big = lorem * 41943 fp.write(big) fp.flush() temporary_fps.append(fp) filepaths = [f.name for f in temporary_fps] self.client.put_multiple(filepaths, bucket_url(""), num_process=2) for fp in temporary_fps: basename = os.path.basename(fp.name) self.assertTrue(self.client.exists(bucket_url(basename))) self.assertEqual(big, self.client.download(bucket_url(basename)).read()) fp.close() @pytest.mark.gcloud class GCSTargetTest(_GCSBaseTestCase, FileSystemTargetTestMixin): def create_target(self, format=None): return gcs.GCSTarget(bucket_url(self.id()), format=format, client=self.client) def test_close_twice(self): # Ensure 
gcs._DeleteOnCloseFile().close() can be called multiple times tgt = self.create_target() with tgt.open("w") as dst: dst.write("data") assert dst.closed dst.close() assert dst.closed with tgt.open() as src: assert src.read().strip() == "data" assert src.closed src.close() assert src.closed class RetryTest(unittest.TestCase): def test_success_with_retryable_error(self): m = mock.MagicMock(side_effect=[IOError, IOError, "test_func_output"]) @gcs.gcs_retry def mock_func(): return m() actual = mock_func() expected = "test_func_output" self.assertEqual(expected, actual) def test_fail_with_retry_limit_exceed(self): m = mock.MagicMock(side_effect=[IOError, IOError, IOError, IOError, IOError]) @gcs.gcs_retry def mock_func(): return m() with self.assertRaises(IOError): mock_func() ================================================ FILE: test/contrib/hadoop_jar_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import shlex import tempfile import pytest from helpers import unittest from mock import Mock, patch import luigi from luigi.contrib.hadoop_jar import HadoopJarJobError, HadoopJarJobTask, fix_paths class TestHadoopJarJob(HadoopJarJobTask): path = luigi.Parameter() def jar(self): return self.path class TestMissingJarJob(HadoopJarJobTask): pass class TestRemoteHadoopJarJob(TestHadoopJarJob): def ssh(self): return {"host": "myhost", "key_file": "file", "username": "user"} class TestRemoteMissingJarJob(TestHadoopJarJob): def ssh(self): return {"host": "myhost", "key_file": "file"} class TestRemoteHadoopJarTwoParamJob(TestRemoteHadoopJarJob): param2 = luigi.Parameter() @pytest.mark.apache class FixPathsTest(unittest.TestCase): def test_fix_paths_non_hdfs_target_path(self): mock_job = Mock() mock_arg = Mock() mock_job.args.return_value = [mock_arg] mock_arg.path = "right_path" self.assertEqual(([], ["right_path"]), fix_paths(mock_job)) def test_fix_paths_non_hdfs_target_str(self): mock_job = Mock() mock_arg = Mock(spec=[]) mock_job.args.return_value = [mock_arg] self.assertEqual(([], [str(mock_arg)]), fix_paths(mock_job)) class HadoopJarJobTaskTest(unittest.TestCase): @patch("luigi.contrib.hadoop.run_and_track_hadoop_job") def test_good(self, mock_job): mock_job.return_value = None with tempfile.NamedTemporaryFile() as temp_file: task = TestHadoopJarJob(temp_file.name) task.run() @patch("luigi.contrib.hadoop.run_and_track_hadoop_job") def test_missing_jar(self, mock_job): mock_job.return_value = None task = TestMissingJarJob() self.assertRaises(HadoopJarJobError, task.run) @patch("luigi.contrib.hadoop.run_and_track_hadoop_job") def test_remote_job(self, mock_job): mock_job.return_value = None with tempfile.NamedTemporaryFile() as temp_file: task = TestRemoteHadoopJarJob(temp_file.name) task.run() @patch("luigi.contrib.hadoop.run_and_track_hadoop_job") def test_remote_job_with_space_in_task_id(self, mock_job): with tempfile.NamedTemporaryFile() as temp_file: def 
check_space(arr, task_id): for a in arr: if a.startswith("hadoop jar"): found = False for x in shlex.split(a): if task_id in x: found = True if not found: raise AssertionError task = TestRemoteHadoopJarTwoParamJob(temp_file.name, "test") mock_job.side_effect = lambda x, _: check_space(x, str(task)) task.run() @patch("luigi.contrib.hadoop.run_and_track_hadoop_job") def test_remote_job_missing_config(self, mock_job): mock_job.return_value = None with tempfile.NamedTemporaryFile() as temp_file: task = TestRemoteMissingJarJob(temp_file.name) self.assertRaises(HadoopJarJobError, task.run) ================================================ FILE: test/contrib/hdfs/webhdfs_client_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015 VNG Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import unittest import pytest from helpers import with_config from luigi.contrib.hdfs import WebHdfsClient InsecureClient = pytest.importorskip("hdfs.InsecureClient") KerberosClient = pytest.importorskip("hdfs.ext.kerberos.KerberosClient") @pytest.mark.apache class TestWebHdfsClient(unittest.TestCase): @with_config({"webhdfs": {"client_type": "insecure"}}) def test_insecure_client_type(self): client = WebHdfsClient(host="localhost").client self.assertIsInstance(client, InsecureClient) @with_config({"webhdfs": {"client_type": "kerberos"}}) def test_kerberos_client_type(self): client = WebHdfsClient(host="localhost").client self.assertIsInstance(client, KerberosClient) ================================================ FILE: test/contrib/hdfs_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import pickle import random import re from target_test import FileSystemTargetTestMixin import luigi import luigi.contrib.hdfs.clients import luigi.format from luigi.contrib import hdfs class ComplexOldFormat(luigi.format.Format): """Should take unicode but output bytes""" def hdfs_writer(self, output_pipe): return self.pipe_writer(luigi.contrib.hdfs.Plain.hdfs_writer(output_pipe)) def pipe_writer(self, output_pipe): return luigi.format.UTF8.pipe_writer(output_pipe) def pipe_reader(self, output_pipe): return output_pipe class TestException(Exception): pass class HdfsTargetTestMixin(FileSystemTargetTestMixin): def create_target(self, format=None): target = hdfs.HdfsTarget(self._test_file(), format=format) if target.exists(): target.remove(skip_trash=True) return target def test_slow_exists(self): target = hdfs.HdfsTarget(self._test_file()) try: target.remove(skip_trash=True) except BaseException: pass self.assertFalse(self.fs.exists(target.path)) target.open("w").close() self.assertTrue(self.fs.exists(target.path)) def should_raise(): self.fs.exists("hdfs://doesnotexist/foo") self.assertRaises(hdfs.HDFSCliError, should_raise) def should_raise_2(): self.fs.exists("hdfs://_doesnotexist_/foo") self.assertRaises(hdfs.HDFSCliError, should_raise_2) def test_create_ancestors(self): parent = self._test_dir() target = hdfs.HdfsTarget("%s/foo/bar/baz" % parent) if self.fs.exists(parent): self.fs.remove(parent, skip_trash=True) self.assertFalse(self.fs.exists(parent)) fobj = target.open("w") fobj.write("lol\n") fobj.close() self.assertTrue(self.fs.exists(parent)) self.assertTrue(target.exists()) def test_tmp_cleanup(self): path = self._test_file() target = hdfs.HdfsTarget(path, is_tmp=True) if target.exists(): target.remove(skip_trash=True) with target.open("w") as fobj: fobj.write("lol\n") self.assertTrue(target.exists()) del target import gc gc.collect() self.assertFalse(self.fs.exists(path)) def test_luigi_tmp(self): target = hdfs.HdfsTarget(is_tmp=True) 
self.assertFalse(target.exists()) with target.open("w"): pass self.assertTrue(target.exists()) def test_tmp_move(self): target = hdfs.HdfsTarget(is_tmp=True) target2 = hdfs.HdfsTarget(self._test_file()) if target2.exists(): target2.remove(skip_trash=True) with target.open("w"): pass self.assertTrue(target.exists()) target.move(target2.path) self.assertFalse(target.exists()) self.assertTrue(target2.exists()) def test_rename_no_parent(self): parent = self._test_dir() + "/foo" if self.fs.exists(parent): self.fs.remove(parent, skip_trash=True) target1 = hdfs.HdfsTarget(is_tmp=True) target2 = hdfs.HdfsTarget(parent + "/bar") with target1.open("w"): pass self.assertTrue(target1.exists()) target1.move(target2.path) self.assertFalse(target1.exists()) self.assertTrue(target2.exists()) def test_rename_no_grandparent(self): grandparent = self._test_dir() + "/foo" if self.fs.exists(grandparent): self.fs.remove(grandparent, skip_trash=True) target1 = hdfs.HdfsTarget(is_tmp=True) target2 = hdfs.HdfsTarget(grandparent + "/bar/baz") with target1.open("w"): pass self.assertTrue(target1.exists()) target1.move(target2.path) self.assertFalse(target1.exists()) self.assertTrue(target2.exists()) def test_glob_exists(self): target_dir = hdfs.HdfsTarget(self._test_dir()) if target_dir.exists(): target_dir.remove(skip_trash=True) self.fs.mkdir(target_dir.path) t1 = hdfs.HdfsTarget(target_dir.path + "/part-00001") t2 = hdfs.HdfsTarget(target_dir.path + "/part-00002") t3 = hdfs.HdfsTarget(target_dir.path + "/another") with t1.open("w") as f: f.write("foo\n") with t2.open("w") as f: f.write("bar\n") with t3.open("w") as f: f.write("biz\n") files = hdfs.HdfsTarget("%s/part-0000*" % target_dir.path) self.assertTrue(files.glob_exists(2)) self.assertFalse(files.glob_exists(3)) self.assertFalse(files.glob_exists(1)) def assertRegexpMatches(self, text, expected_regexp, msg=None): """Python 2.7 backport.""" if isinstance(expected_regexp, str): expected_regexp = re.compile(expected_regexp) if not 
expected_regexp.search(text): msg = msg or "Regexp didn't match" msg = "%s: %r not found in %r" % (msg, expected_regexp.pattern, text) raise self.failureException(msg) def test_tmppath_not_configured(self): # Given: several target paths to test path1 = "/dir1/dir2/file" path2 = "hdfs:///dir1/dir2/file" path3 = "hdfs://somehost/dir1/dir2/file" path4 = "file:///dir1/dir2/file" path5 = "/tmp/dir/file" path6 = "file:///tmp/dir/file" path7 = "hdfs://somehost/tmp/dir/file" path8 = None path9 = "/tmpdir/file" # When: I create a temporary path for targets res1 = hdfs.tmppath(path1, include_unix_username=False) res2 = hdfs.tmppath(path2, include_unix_username=False) res3 = hdfs.tmppath(path3, include_unix_username=False) res4 = hdfs.tmppath(path4, include_unix_username=False) res5 = hdfs.tmppath(path5, include_unix_username=False) res6 = hdfs.tmppath(path6, include_unix_username=False) res7 = hdfs.tmppath(path7, include_unix_username=False) res8 = hdfs.tmppath(path8, include_unix_username=False) res9 = hdfs.tmppath(path9, include_unix_username=False) # Then: I should get correct results relative to Luigi temporary directory self.assertRegexpMatches(res1, "^/tmp/dir1/dir2/file-luigitemp-\\d+") # it would be better to see hdfs:///path instead of hdfs:/path, but single slash also works well self.assertRegexpMatches(res2, "^hdfs:/tmp/dir1/dir2/file-luigitemp-\\d+") self.assertRegexpMatches(res3, "^hdfs://somehost/tmp/dir1/dir2/file-luigitemp-\\d+") self.assertRegexpMatches(res4, "^file:///tmp/dir1/dir2/file-luigitemp-\\d+") self.assertRegexpMatches(res5, "^/tmp/dir/file-luigitemp-\\d+") # known issue with duplicated "tmp" if schema is present self.assertRegexpMatches(res6, "^file:///tmp/tmp/dir/file-luigitemp-\\d+") # known issue with duplicated "tmp" if schema is present self.assertRegexpMatches(res7, "^hdfs://somehost/tmp/tmp/dir/file-luigitemp-\\d+") self.assertRegexpMatches(res8, "^/tmp/luigitemp-\\d+") self.assertRegexpMatches(res9, "/tmp/tmpdir/file") def 
test_tmppath_username(self): self.assertRegexpMatches(hdfs.tmppath("/path/to/stuff", include_unix_username=True), "^/tmp/[a-z0-9_]+/path/to/stuff-luigitemp-\\d+") def test_pickle(self): t = hdfs.HdfsTarget("/tmp/dir") pickle.dumps(t) def test_flag_target(self): target = hdfs.HdfsFlagTarget("/some/dir/", format=format) if target.exists(): target.remove(skip_trash=True) self.assertFalse(target.exists()) t1 = hdfs.HdfsTarget(target.path + "part-00000", format=format) with t1.open("w"): pass t2 = hdfs.HdfsTarget(target.path + "_SUCCESS", format=format) with t2.open("w"): pass self.assertTrue(target.exists()) def test_flag_target_fails_if_not_directory(self): with self.assertRaises(ValueError): hdfs.HdfsFlagTarget("/home/file.txt") class _MiscOperationsMixin: # TODO: chown/chmod/count should really be methods on HdfsTarget rather than the client! def get_target(self): fn = "/tmp/foo-%09d" % random.randint(0, 999999999) t = luigi.contrib.hdfs.HdfsTarget(fn) with t.open("w") as f: f.write("test") return t def test_count(self): t = self.get_target() res = self.get_client().count(t.path) for key in ["content_size", "dir_count", "file_count"]: self.assertTrue(key in res) def test_chmod(self): t = self.get_target() self.get_client().chmod(t.path, "777") def test_chown(self): t = self.get_target() self.get_client().chown(t.path, "root", "root") ================================================ FILE: test/contrib/hive_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # import os import sys import tempfile from collections import OrderedDict import mock import pytest from helpers import unittest import luigi.contrib.hive from luigi import LocalTarget @pytest.mark.apache class HiveTest(unittest.TestCase): count = 0 def mock_hive_cmd(self, args, check_return=True): self.last_hive_cmd = args self.count += 1 return "statement{}".format(self.count) def setUp(self): self.run_hive_cmd_saved = luigi.contrib.hive.run_hive luigi.contrib.hive.run_hive = self.mock_hive_cmd def tearDown(self): luigi.contrib.hive.run_hive = self.run_hive_cmd_saved def test_run_hive_command(self): pre_count = self.count res = luigi.contrib.hive.run_hive_cmd("foo") self.assertEqual(["-e", "foo"], self.last_hive_cmd) self.assertEqual("statement{0}".format(pre_count + 1), res) def test_run_hive_script_not_exists(self): def test(): luigi.contrib.hive.run_hive_script("/tmp/some-non-existant-file______") self.assertRaises(RuntimeError, test) def test_run_hive_script_exists(self): with tempfile.NamedTemporaryFile(delete=True) as f: pre_count = self.count res = luigi.contrib.hive.run_hive_script(f.name) self.assertEqual(["-f", f.name], self.last_hive_cmd) self.assertEqual("statement{0}".format(pre_count + 1), res) def test_create_parent_dirs(self): dirname = "/tmp/hive_task_test_dir" class FooHiveTask: def output(self): return LocalTarget(os.path.join(dirname, "foo")) runner = luigi.contrib.hive.HiveQueryRunner() runner.prepare_outputs(FooHiveTask()) self.assertTrue(os.path.exists(dirname)) @pytest.mark.apache class HiveCommandClientTest(unittest.TestCase): """Note that some of these tests are really for the CDH releases of Hive, to which I do not currently have access. 
Hopefully there are no significant differences in the expected output""" def setUp(self): self.client = luigi.contrib.hive.HiveCommandClient() self.apacheclient = luigi.contrib.hive.ApacheHiveCommandClient() self.metastoreclient = luigi.contrib.hive.MetastoreClient() @mock.patch("luigi.contrib.hive.run_hive_cmd") def test_default_table_location(self, run_command): run_command.return_value = ( "Protect Mode: None \n" "Retention: 0 \n" "Location: hdfs://localhost:9000/user/hive/warehouse/mytable \n" "Table Type: MANAGED_TABLE \n" ) returned = self.client.table_location("mytable") self.assertEqual("hdfs://localhost:9000/user/hive/warehouse/mytable", returned) @mock.patch("luigi.contrib.hive.run_hive_cmd") def test_table_exists(self, run_command): run_command.return_value = "OK" returned = self.client.table_exists("mytable") self.assertFalse(returned) run_command.return_value = "OK\nmytable" returned = self.client.table_exists("mytable") self.assertTrue(returned) # Issue #896 test case insensitivity returned = self.client.table_exists("MyTable") self.assertTrue(returned) run_command.return_value = "day=2013-06-28/hour=3\nday=2013-06-28/hour=4\nday=2013-07-07/hour=2\n" self.client.partition_spec = mock.Mock(name="partition_spec") self.client.partition_spec.return_value = "somepart" returned = self.client.table_exists("mytable", partition={"a": "b"}) self.assertTrue(returned) run_command.return_value = "" returned = self.client.table_exists("mytable", partition={"a": "b"}) self.assertFalse(returned) @mock.patch("luigi.contrib.hive.run_hive_cmd") def test_table_schema(self, run_command): run_command.return_value = "FAILED: SemanticException [Error 10001]: blah does not exist\nSome other stuff" returned = self.client.table_schema("mytable") self.assertFalse(returned) run_command.return_value = ( "OK\n" "col1 string None \n" "col2 string None \n" "col3 string None \n" "day string None \n" "hour smallint None \n\n" "# Partition Information \n" "# col_name data_type comment 
\n\n" "day string None \n" "hour smallint None \n" "Time taken: 2.08 seconds, Fetched: 34 row(s)\n" ) expected = [ ("OK",), ("col1", "string", "None"), ("col2", "string", "None"), ("col3", "string", "None"), ("day", "string", "None"), ("hour", "smallint", "None"), ("",), ("# Partition Information",), ("# col_name", "data_type", "comment"), ("",), ("day", "string", "None"), ("hour", "smallint", "None"), ("Time taken: 2.08 seconds, Fetched: 34 row(s)",), ] returned = self.client.table_schema("mytable") self.assertEqual(expected, returned) def test_partition_spec(self): returned = self.client.partition_spec({"a": "b", "c": "d"}) self.assertEqual("`a`='b',`c`='d'", returned) @mock.patch("luigi.contrib.hive.run_hive_cmd") def test_apacheclient_table_exists(self, run_command): run_command.return_value = "OK" returned = self.apacheclient.table_exists("mytable") self.assertFalse(returned) run_command.return_value = "OK\nmytable" returned = self.apacheclient.table_exists("mytable") self.assertTrue(returned) # Issue #896 test case insensitivity returned = self.apacheclient.table_exists("MyTable") self.assertTrue(returned) run_command.return_value = "day=2013-06-28/hour=3\nday=2013-06-28/hour=4\nday=2013-07-07/hour=2\n" self.apacheclient.partition_spec = mock.Mock(name="partition_spec") self.apacheclient.partition_spec.return_value = "somepart" returned = self.apacheclient.table_exists("mytable", partition={"a": "b"}) self.assertTrue(returned) run_command.return_value = "" returned = self.apacheclient.table_exists("mytable", partition={"a": "b"}) self.assertFalse(returned) @mock.patch("luigi.contrib.hive.run_hive_cmd") def test_apacheclient_table_schema(self, run_command): run_command.return_value = "FAILED: SemanticException [Error 10001]: Table not found mytable\nSome other stuff" returned = self.apacheclient.table_schema("mytable") self.assertFalse(returned) run_command.return_value = ( "OK\n" "col1 string None \n" "col2 string None \n" "col3 string None \n" "day string 
None \n" "hour smallint None \n\n" "# Partition Information \n" "# col_name data_type comment \n\n" "day string None \n" "hour smallint None \n" "Time taken: 2.08 seconds, Fetched: 34 row(s)\n" ) expected = [ ("OK",), ("col1", "string", "None"), ("col2", "string", "None"), ("col3", "string", "None"), ("day", "string", "None"), ("hour", "smallint", "None"), ("",), ("# Partition Information",), ("# col_name", "data_type", "comment"), ("",), ("day", "string", "None"), ("hour", "smallint", "None"), ("Time taken: 2.08 seconds, Fetched: 34 row(s)",), ] returned = self.apacheclient.table_schema("mytable") self.assertEqual(expected, returned) @mock.patch("luigi.contrib.hive.HiveThriftContext") def test_metastoreclient_partition_existence_regardless_of_order(self, thrift_context): thrift_context.return_value = thrift_context client_mock = mock.Mock(name="clientmock") client_mock.return_value = client_mock thrift_context.__enter__ = client_mock client_mock.get_partition_names = mock.Mock(return_value=["p1=x/p2=y", "p1=a/p2=b"]) partition_spec = OrderedDict([("p1", "a"), ("p2", "b")]) self.assertTrue(self.metastoreclient.table_exists("table", "default", partition_spec)) partition_spec = OrderedDict([("p2", "b"), ("p1", "a")]) self.assertTrue(self.metastoreclient.table_exists("table", "default", partition_spec)) def test_metastore_partition_spec_has_the_same_order(self): partition_spec = OrderedDict([("p1", "a"), ("p2", "b")]) spec_string = luigi.contrib.hive.MetastoreClient().partition_spec(partition_spec) self.assertEqual(spec_string, "p1=a/p2=b") partition_spec = OrderedDict([("p2", "b"), ("p1", "a")]) spec_string = luigi.contrib.hive.MetastoreClient().partition_spec(partition_spec) self.assertEqual(spec_string, "p1=a/p2=b") @mock.patch("luigi.configuration") def test_client_def(self, hive_syntax): hive_syntax.get_config.return_value.get.return_value = "cdh4" client = luigi.contrib.hive.get_default_client() self.assertEqual(luigi.contrib.hive.HiveCommandClient, 
type(client)) hive_syntax.get_config.return_value.get.return_value = "cdh3" client = luigi.contrib.hive.get_default_client() self.assertEqual(luigi.contrib.hive.HiveCommandClient, type(client)) hive_syntax.get_config.return_value.get.return_value = "apache" client = luigi.contrib.hive.get_default_client() self.assertEqual(luigi.contrib.hive.ApacheHiveCommandClient, type(client)) hive_syntax.get_config.return_value.get.return_value = "metastore" client = luigi.contrib.hive.get_default_client() self.assertEqual(luigi.contrib.hive.MetastoreClient, type(client)) hive_syntax.get_config.return_value.get.return_value = "warehouse" client = luigi.contrib.hive.get_default_client() self.assertEqual(luigi.contrib.hive.WarehouseHiveClient, type(client)) @mock.patch("subprocess.Popen") def test_run_hive_command(self, popen): # I'm testing this again to check the return codes # I didn't want to tear up all the existing tests to change how run_hive is mocked comm = mock.Mock(name="communicate_mock") comm.return_value = b"some return stuff", "" preturn = mock.Mock(name="open_mock") preturn.returncode = 0 preturn.communicate = comm popen.return_value = preturn returned = luigi.contrib.hive.run_hive(["blah", "blah"]) self.assertEqual("some return stuff", returned) preturn.returncode = 17 self.assertRaises(luigi.contrib.hive.HiveCommandError, luigi.contrib.hive.run_hive, ["blah", "blah"]) comm.return_value = b"", "some stderr stuff" returned = luigi.contrib.hive.run_hive(["blah", "blah"], False) self.assertEqual("", returned) class WarehouseHiveClientTest(unittest.TestCase): def test_table_exists_files_actually_exist(self): # arrange hdfs_client = mock.Mock(name="hdfs_client") hdfs_client.exists.return_value = True hdfs_client.listdir.return_value = ["00000_0", "00000_1", "00000_2", ".tmp/"] warehouse_hive_client = luigi.contrib.hive.WarehouseHiveClient(hdfs_client=hdfs_client, warehouse_location="/apps/hive/warehouse") # act exists = 
warehouse_hive_client.table_exists(database="some_db", table="table_name", partition=OrderedDict(a=1, b=2)) # assert assert exists hdfs_client.exists.assert_called_once_with("/apps/hive/warehouse/some_db.db/table_name/a=1/b=2") @mock.patch("luigi.configuration") def test_table_exists_without_partition_spec_files_actually_exist(self, warehouse_location): # arrange warehouse_location.get_config.return_value.get.return_value = "/apps/hive/warehouse" hdfs_client = mock.Mock(name="hdfs_client") hdfs_client.exists.return_value = True hdfs_client.listdir.return_value = ["00000_0", "00000_1", "00000_2", ".tmp/"] warehouse_hive_client = luigi.contrib.hive.WarehouseHiveClient( hdfs_client=hdfs_client, ) # act exists = warehouse_hive_client.table_exists( database="some_db", table="table_name", ) # assert assert exists hdfs_client.exists.assert_called_once_with("/apps/hive/warehouse/some_db.db/table_name/") hdfs_client.listdir.assert_called_once_with("/apps/hive/warehouse/some_db.db/table_name/") @mock.patch("luigi.configuration") def test_table_exists_only_tmp_files_exist(self, ignored_file_masks): # arrange ignored_file_masks.get_config.return_value.get.return_value = r"(\.tmp.*)" hdfs_client = mock.Mock(name="hdfs_client") hdfs_client.exists.return_value = True hdfs_client.listdir.return_value = [".tmp/"] warehouse_hive_client = luigi.contrib.hive.WarehouseHiveClient(hdfs_client=hdfs_client, warehouse_location="/apps/hive/warehouse") # act exists = warehouse_hive_client.table_exists(database="some_db", table="table_name", partition={"a": 1}) # assert assert not exists hdfs_client.exists.assert_called_once_with("/apps/hive/warehouse/some_db.db/table_name/a=1") hdfs_client.listdir.assert_called_once_with("/apps/hive/warehouse/some_db.db/table_name/a=1") @mock.patch("luigi.configuration") def test_table_exists_ambiguous_partition(self, ignored_file_masks): # arrange ignored_file_masks.get_config.return_value.get.return_value = r"(\.tmp.*)" hdfs_client = 
mock.Mock(name="hdfs_client") hdfs_client.exists.return_value = True hdfs_client.listdir.return_value = [".tmp/"] warehouse_hive_client = luigi.contrib.hive.WarehouseHiveClient(hdfs_client=hdfs_client, warehouse_location="/apps/hive/warehouse") def _call_exists(): return warehouse_hive_client.table_exists(database="some_db", table="table_name", partition={"a": 1, "b": 2}) # act & assert if sys.version_info >= (3, 7): exists = _call_exists() assert not exists hdfs_client.exists.assert_called_once_with("/apps/hive/warehouse/some_db.db/table_name/a=1/b=2") hdfs_client.listdir.assert_called_once_with("/apps/hive/warehouse/some_db.db/table_name/a=1/b=2") else: self.assertRaises(ValueError, _call_exists) class MyHiveTask(luigi.contrib.hive.HiveQueryTask): param = luigi.Parameter() def query(self): return "banana banana %s" % self.param @pytest.mark.apache class TestHiveTask(unittest.TestCase): task_class = MyHiveTask @mock.patch("luigi.contrib.hadoop.run_and_track_hadoop_job") def test_run(self, run_and_track_hadoop_job): success = luigi.run([self.task_class.__name__, "--param", "foo", "--local-scheduler", "--no-lock"]) self.assertTrue(success) self.assertEqual("hive", run_and_track_hadoop_job.call_args[0][0][0]) class MyHiveTaskArgs(MyHiveTask): def hivevars(self): return {"my_variable1": "value1", "my_variable2": "value2"} def hiveconfs(self): return {"hive.additional.conf": "conf_value"} class TestHiveTaskArgs(TestHiveTask): task_class = MyHiveTaskArgs def test_arglist(self): task = self.task_class(param="foo") f_name = "my_file" runner = luigi.contrib.hive.HiveQueryRunner() arglist = runner.get_arglist(f_name, task) f_idx = arglist.index("-f") self.assertEqual(arglist[f_idx + 1], f_name) hivevars = ["{}={}".format(k, v) for k, v in task.hivevars().items()] for var in hivevars: idx = arglist.index(var) self.assertEqual(arglist[idx - 1], "--hivevar") hiveconfs = ["{}={}".format(k, v) for k, v in task.hiveconfs().items()] for conf in hiveconfs: idx = arglist.index(conf) 
self.assertEqual(arglist[idx - 1], "--hiveconf") @pytest.mark.apache class TestHiveTarget(unittest.TestCase): def test_hive_table_target(self): client = mock.Mock() target = luigi.contrib.hive.HiveTableTarget(database="db", table="foo", client=client) target.exists() client.table_exists.assert_called_with("foo", "db", None) def test_hive_partition_target(self): client = mock.Mock() target = luigi.contrib.hive.HivePartitionTarget(database="db", table="foo", partition="bar", client=client) target.exists() client.table_exists.assert_called_with("foo", "db", "bar") class ExternalHiveTaskTest(unittest.TestCase): def test_table(self): # arrange class _Task(luigi.contrib.hive.ExternalHiveTask): database = "schema1" table = "table1" # act output = _Task().output() # assert assert isinstance(output, luigi.contrib.hive.HivePartitionTarget) assert output.database == "schema1" assert output.table == "table1" assert output.partition == {} def test_partition_exists(self): # arrange class _Task(luigi.contrib.hive.ExternalHiveTask): database = "schema2" table = "table2" partition = {"a": 1} # act output = _Task().output() # assert assert isinstance(output, luigi.contrib.hive.HivePartitionTarget) assert output.database == "schema2" assert output.table == "table2" assert output.partition == {"a": 1} ================================================ FILE: test/contrib/kubernetes_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015 Outlier Bio, LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # """ Tests for the Kubernetes Job wrapper. Requires: - pykube: ``pip install pykube-ng`` - A local minikube custer up and running: http://kubernetes.io/docs/getting-started-guides/minikube/ **WARNING**: For Python versions < 3.5 the kubeconfig file must point to a Kubernetes API hostname, and NOT to an IP address. Written and maintained by Marco Capuccini (@mcapuccini). """ import logging import unittest import mock import pytest import luigi from luigi.contrib.kubernetes import KubernetesJobTask logger = logging.getLogger("luigi-interface") try: from pykube.config import KubeConfig from pykube.http import HTTPClient from pykube.objects import Job except ImportError: raise unittest.SkipTest("pykube is not installed. This test requires pykube.") class SuccessJob(KubernetesJobTask): name = "success" spec_schema = {"containers": [{"name": "hello", "image": "alpine:3.4", "command": ["echo", "Hello World!"]}]} class FailJob(KubernetesJobTask): name = "fail" max_retrials = 3 backoff_limit = 3 spec_schema = {"containers": [{"name": "fail", "image": "alpine:3.4", "command": ["You", "Shall", "Not", "Pass"]}]} @property def labels(self): return {"dummy_label": "dummy_value"} @pytest.mark.contrib class TestK8STask(unittest.TestCase): def test_success_job(self): success = luigi.run(["SuccessJob", "--local-scheduler"]) self.assertTrue(success) def test_fail_job(self): fail = FailJob() self.assertRaises(RuntimeError, fail.run) # Check for retrials kube_api = HTTPClient(KubeConfig.from_file("~/.kube/config")) # assumes minikube jobs = Job.objects(kube_api).filter(selector="luigi_task_id=" + fail.job_uuid) self.assertEqual(len(jobs.response["items"]), 1) job = Job(kube_api, jobs.response["items"][0]) self.assertTrue("failed" in job.obj["status"]) self.assertTrue(job.obj["status"]["failed"] > fail.max_retrials) self.assertTrue(job.obj["spec"]["template"]["metadata"]["labels"] == 
fail.labels()) @mock.patch.object(KubernetesJobTask, "_KubernetesJobTask__get_job_status") @mock.patch.object(KubernetesJobTask, "signal_complete") def test_output(self, mock_signal, mock_job_status): # mock that the job succeeded mock_job_status.return_value = "succeeded" # create a kubernetes job kubernetes_job = KubernetesJobTask() # set logger and uu_name due to logging in __track_job() kubernetes_job._KubernetesJobTask__logger = logger kubernetes_job.uu_name = "test" # track the job (bc included in run method) kubernetes_job._KubernetesJobTask__track_job() # Make sure successful job signals self.assertTrue(mock_signal.called) ================================================ FILE: test/contrib/lsf_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ LSF Unit Test ============= Test runner for the LSF wrapper. 
The test is based on the one used for the SGE wrappers """ import logging import os import os.path import subprocess import unittest from glob import glob import pytest from mock import patch import luigi from luigi.contrib.lsf import LSFJobTask DEFAULT_HOME = "" LOGGER = logging.getLogger("luigi-interface") # BJOBS_OUTPUT = """JOBID USER STAT QUEUE FROM_HOST EXEC_HOST JOB_NAME SUBMIT_TIME # 1000001 mcdowal RUN production sub-node-002 node4-123 /bin/bash Mar 14 10:10 # 1000002 mcdowal PEND production sub-node-002 node5-269 /bin/bash Mar 14 10:10 # 1000003 mcdowal EXIT production sub-node-002 /bin/bash Mar 14 10:10 # """ def on_lsf_master(): try: subprocess.check_call("bjobs", shell=True) return True except subprocess.CalledProcessError: return False class TestJobTask(LSFJobTask): """Simple SGE job: write a test file to NSF shared drive and waits a minute""" i = luigi.Parameter() def work(self): LOGGER.info("Running test job...") with open(self.output().path, "w") as f: f.write("this is a test\n") def output(self): return luigi.LocalTarget(os.path.join(DEFAULT_HOME, "test_lsf_file_" + str(self.i))) @pytest.mark.contrib class TestSGEJob(unittest.TestCase): """Test from SGE master node""" @patch("subprocess.Popen") @patch("subprocess.Popen.communicate") def test_run_job(self, mock_open, mock_communicate): if on_lsf_master(): outfile = os.path.join(DEFAULT_HOME, "testfile_1") tasks = [TestJobTask(i=str(i), n_cpu_flag=1) for i in range(3)] luigi.build(tasks, local_scheduler=True, workers=3) self.assertTrue(os.path.exists(outfile)) @patch("subprocess.Popen") @patch("subprocess.Popen.communicate") def test_run_job_with_dump(self, mock_open, mock_communicate): mock_open.side_effect = ["Job <1000001> is submitted to queue .", ""] task = TestJobTask(i=str(1), n_cpu_flag=1, shared_tmp_dir="/tmp") luigi.build([task], local_scheduler=True) self.assertEqual(mock_open.call_count, 0) def tearDown(self): for fpath in glob(os.path.join(DEFAULT_HOME, "test_lsf_file_*")): try: 
os.remove(fpath) except OSError: pass if __name__ == "__main__": unittest.main() ================================================ FILE: test/contrib/mongo_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2017 Big Datext Inc # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import pytest from helpers import unittest from luigi.contrib.mongodb import MongoCellTarget, MongoRangeTarget HOST = "localhost" PORT = 27017 INDEX = "luigi_test" COLLECTION = "luigi_collection" try: import pymongo mongo_client = pymongo.MongoClient(HOST, PORT) mongo_client.server_info() except ImportError: raise unittest.SkipTest("Unable to load pymongo module") except Exception: raise unittest.SkipTest("Unable to connect to local mongoDB instance") @pytest.mark.contrib class MongoCellTargetTest(unittest.TestCase): """MongoCellTarget unittest on local test database""" def setUp(self): """ Fill test database with fake data """ self.mongo_client = pymongo.MongoClient(HOST, PORT) self.collection = self.mongo_client[INDEX][COLLECTION] self.collection.delete_many({}) test_docs = [ {"_id": "person_1", "name": "Mike", "infos": {"family": "single"}}, {"_id": "person_2", "name": "Laura", "surname": "Gilmore"}, {"_id": "person_3", "surname": "Specter"}, {"_id": "person_4", "surname": "", "infos": {"family": {"children": ["jack", "rose"]}}}, ] self.collection.insert_many(test_docs) def tearDown(self): """ Make sure the test database is in clean state """ 
self.collection.drop() self.mongo_client.drop_database(INDEX) def test_exists(self): test_values = [ ("person_1", "surname", False), ("person_2", "surname", True), ("person_3", "surname", True), ("unknow_person", "surname", False), ] for id_, field, result in test_values: target = MongoCellTarget(self.mongo_client, INDEX, COLLECTION, id_, field) self.assertEqual(result, target.exists()) def test_exists_nested(self): test_values = [ ("person_1", "infos", True), ("person_1", "infos.family", True), ("person_2", "family", False), ("person_4", "infos", True), ("person_4", "infos.family", True), ("person_4", "infos.sexe", False), ("person_4", "infos.family.children", True), ("person_4", "infos.family.aunt", False), ] for id_, path, result in test_values: target = MongoCellTarget(self.mongo_client, INDEX, COLLECTION, id_, path) self.assertEqual(result, target.exists()) def test_read(self): test_values = [ ("person_1", "surname", None), ("person_2", "surname", "Gilmore"), ("person_3", "surname", "Specter"), ("person_4", "surname", ""), ("unknown_person", "surname", None), ] for id_, field, result in test_values: target = MongoCellTarget(self.mongo_client, INDEX, COLLECTION, id_, field) self.assertEqual(result, target.read()) def test_read_nested(self): test_values = [ ("person_1", "infos", {"family": "single"}), ("person_1", "infos.family", "single"), ("person_2", "family", None), ("person_4", "infos", {"family": {"children": ["jack", "rose"]}}), ("person_4", "infos.family", {"children": ["jack", "rose"]}), ("person_4", "infos.sexe", None), ("person_4", "infos.family.children", ["jack", "rose"]), ] for id_, path, result in test_values: target = MongoCellTarget(self.mongo_client, INDEX, COLLECTION, id_, path) self.assertEqual(result, target.read()) def test_write(self): ids = ["person_1", "person_2", "person_3", "person_4", "unknow_person"] for id_ in ids: self.setUp() target = MongoCellTarget(self.mongo_client, INDEX, COLLECTION, id_, "age") target.write("100") 
self.assertEqual(target.read(), "100") def test_write_nested(self): test_values = [ ("person_1", "infos", 12), ("person_1", "infos.family", ["ambre", "justin", "sophia"]), ("person_2", "hobbies", {"soccer": True}), ("person_3", "infos", {"age": "100"}), ("person_3", "infos.hobbies", {"soccer": True}), ("person_3", "infos.hobbies.soccer", [{"status": "young"}, "strong", "fast"]), ] for id_, path, new_value in test_values: self.setUp() target = MongoCellTarget(self.mongo_client, INDEX, COLLECTION, id_, path) target.write(new_value) self.assertEqual(target.read(), new_value) self.tearDown() @pytest.mark.contrib class MongoRangerTargetTest(unittest.TestCase): """MongoRangelTarget unittest on local test database""" def setUp(self): """ Fill test database with fake data """ self.mongo_client = pymongo.MongoClient(HOST, PORT) self.collection = self.mongo_client[INDEX][COLLECTION] self.collection.delete_many({}) test_docs = [ {"_id": "person_1", "age": 11, "experience": 10, "content": "Lorem ipsum, dolor sit amet. Consectetur adipiscing elit."}, {"_id": "person_2", "age": 12, "experience": 22, "content": "Sed purus nisl. Faucibus in, erat eu. Rhoncus mattis velit."}, {"_id": "person_3", "age": 13, "content": "Nulla malesuada, fringilla lorem at pellentesque."}, {"_id": "person_4", "age": 14, "content": "Curabitur condimentum. 
Venenatis fringilla."}, ] self.collection.insert_many(test_docs) def tearDown(self): """ Make sure the test database is in clean state """ self.collection.drop() self.mongo_client.drop_database(INDEX) def test_exists(self): test_values = [ ("age", [], True), ("age", ["person_1", "person_2", "person_3"], True), ("experience", ["person_1", "person_2", "person_3", "person_4"], False), ("experience", ["person_1", "person_2"], True), ("unknow_field", ["person_1", "person_2"], False), ("experience", ["unknow_person"], False), ("experience", ["person_1", "unknown_person"], False), ("experience", ["person_3", "unknown_person"], False), ] for field, ids, result in test_values: target = MongoRangeTarget(self.mongo_client, INDEX, COLLECTION, ids, field) self.assertEqual(result, target.exists()) def test_read(self): test_values = [ ("age", [], {}), ("age", ["unknown_person"], {}), ("age", ["person_1", "person_3"], {"person_1": 11, "person_3": 13}), ("age", ["person_1", "person_3", "person_5"], {"person_1": 11, "person_3": 13}), ("experience", ["person_1", "person_3"], {"person_1": 10}), ("experience", ["person_1", "person_3", "person_5"], {"person_1": 10}), ] for field, ids, result in test_values: target = MongoRangeTarget(self.mongo_client, INDEX, COLLECTION, ids, field) self.assertEqual(result, target.read()) def test_write(self): test_values = [ ( "age", # feature ["person_1"], # ids {"person_1": 31}, # arg of write() ({"_id": {"$in": ["person_1"]}}, {"age": True}), # mongo request to fetch result [{"_id": "person_1", "age": 31}], # result ), ( "experience", ["person_1", "person_3"], {"person_1": 31, "person_3": 32}, ({"_id": {"$in": ["person_1", "person_3"]}}, {"experience": True}), [{"_id": "person_1", "experience": 31}, {"_id": "person_3", "experience": 32}], ), ( "experience", [], {"person_3": 18}, ({"_id": {"$in": ["person_1", "person_3"]}}, {"experience": True}), [{"_id": "person_1", "experience": 10}, {"_id": "person_3"}], ), ( "age", ["person_1"], {"person_1": 
["young", "old"]}, ({"_id": "person_1"}, {"age": True}), [{"_id": "person_1", "age": ["young", "old"]}], ), ( "age", ["person_1"], {"person_1": {"feeling_like": 60}}, ({"_id": "person_1"}, {"age": True}), [{"_id": "person_1", "age": {"feeling_like": 60}}], ), ( "age", ["person_1"], {"person_1": [{"feeling_like": 60}, 24]}, ({"_id": "person_1"}, {"age": True}), [{"_id": "person_1", "age": [{"feeling_like": 60}, 24]}], ), ] for field, ids, docs, req, result in test_values: self.setUp() target = MongoRangeTarget(self.mongo_client, INDEX, COLLECTION, ids, field) target.write(docs) self.assertEqual(result, list(self.collection.find(*req))) self.tearDown() ================================================ FILE: test/contrib/mysqldb_test.py ================================================ import datetime import mock import pytest from helpers import unittest import luigi.contrib.mysqldb from luigi.tools.range import RangeDaily def datetime_to_epoch(dt): td = dt - datetime.datetime(1970, 1, 1) return td.days * 86400 + td.seconds + td.microseconds / 1e6 class MockMysqlCursor(mock.Mock): """ Keeps state to simulate executing SELECT queries and fetching results. 
""" def __init__(self, existing_update_ids): super(MockMysqlCursor, self).__init__() self.existing = existing_update_ids def execute(self, query, params): if query.startswith("SELECT 1 FROM table_updates"): self.fetchone_result = (1,) if params[0] in self.existing else None else: self.fetchone_result = None def fetchone(self): return self.fetchone_result class DummyMysqlImporter(luigi.contrib.mysqldb.CopyToTable): date = luigi.DateParameter() host = "dummy_host" database = "dummy_database" user = "dummy_user" password = "dummy_password" table = "dummy_table" columns = ( ("some_text", "text"), ("some_int", "int"), ) # Testing that an existing update will not be run in RangeDaily @pytest.mark.mysql class DailyCopyToTableTest(unittest.TestCase): @mock.patch("mysql.connector.connect") def test_bulk_complete(self, mock_connect): mock_cursor = MockMysqlCursor( [ # Existing update_ids DummyMysqlImporter(date=datetime.datetime(2015, 1, 3)).task_id ] ) mock_connect.return_value.cursor.return_value = mock_cursor task = RangeDaily(of=DummyMysqlImporter, start=datetime.date(2015, 1, 2), now=datetime_to_epoch(datetime.datetime(2015, 1, 7))) actual = sorted([t.task_id for t in task.requires()]) self.assertEqual( actual, sorted( [ DummyMysqlImporter(date=datetime.datetime(2015, 1, 2)).task_id, DummyMysqlImporter(date=datetime.datetime(2015, 1, 4)).task_id, DummyMysqlImporter(date=datetime.datetime(2015, 1, 5)).task_id, DummyMysqlImporter(date=datetime.datetime(2015, 1, 6)).task_id, ] ), ) self.assertFalse(task.complete()) @pytest.mark.mysql class TestCopyToTableWithMetaColumns(unittest.TestCase): @mock.patch("luigi.contrib.mysqldb.CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) @mock.patch("luigi.contrib.mysqldb.CopyToTable._add_metadata_columns") @mock.patch("luigi.contrib.mysqldb.CopyToTable.post_copy_metacolumns") @mock.patch("luigi.contrib.mysqldb.CopyToTable.rows", return_value=["row1", "row2"]) 
@mock.patch("luigi.contrib.mysqldb.MySqlTarget") @mock.patch("mysql.connector.connect") def test_copy_with_metadata_columns_enabled( self, mock_connect, mock_mysql_target, mock_rows, mock_add_columns, mock_update_columns, mock_metadata_columns_enabled ): task = DummyMysqlImporter(date=datetime.datetime(1991, 3, 24)) mock_cursor = MockMysqlCursor([task.task_id]) mock_connect.return_value.cursor.return_value = mock_cursor task = DummyMysqlImporter(date=datetime.datetime(1991, 3, 24)) task.run() self.assertTrue(mock_add_columns.called) self.assertTrue(mock_update_columns.called) @mock.patch("luigi.contrib.mysqldb.CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=False) @mock.patch("luigi.contrib.mysqldb.CopyToTable._add_metadata_columns") @mock.patch("luigi.contrib.mysqldb.CopyToTable.post_copy_metacolumns") @mock.patch("luigi.contrib.mysqldb.CopyToTable.rows", return_value=["row1", "row2"]) @mock.patch("luigi.contrib.mysqldb.MySqlTarget") @mock.patch("mysql.connector.connect") def test_copy_with_metadata_columns_disabled( self, mock_connect, mock_mysql_target, mock_rows, mock_add_columns, mock_update_columns, mock_metadata_columns_enabled ): task = DummyMysqlImporter(date=datetime.datetime(1991, 3, 24)) mock_cursor = MockMysqlCursor([task.task_id]) mock_connect.return_value.cursor.return_value = mock_cursor task.run() self.assertFalse(mock_add_columns.called) self.assertFalse(mock_update_columns.called) ================================================ FILE: test/contrib/opener_test.py ================================================ import random import unittest import mock import pytest import luigi from luigi.contrib.opener import NoOpenerError, OpenerTarget from luigi.local_target import LocalTarget from luigi.mock import MockTarget @pytest.mark.contrib class TestOpenerTarget(unittest.TestCase): def setUp(self): MockTarget.fs.clear() self.local_file = "/tmp/{}/xyz/test.txt".format(random.randint(0, 999999999)) if 
LocalTarget.fs.exists(self.local_file): LocalTarget.fs.remove(self.local_file) def tearDown(self): if LocalTarget.fs.exists(self.local_file): LocalTarget.fs.remove(self.local_file) def test_invalid_target(self): """Verify invalid types raises NoOpenerError""" self.assertRaises(NoOpenerError, OpenerTarget, "foo://bar.txt") def test_mock_target(self): """Verify mock target url""" target = OpenerTarget("mock://foo/bar.txt") self.assertEqual(type(target), MockTarget) # Write to the target target.open("w").close() self.assertTrue(MockTarget.fs.exists("foo/bar.txt")) def test_mock_target_root(self): """Verify mock target url""" target = OpenerTarget("mock:///foo/bar.txt") self.assertEqual(type(target), MockTarget) # Write to the target target.open("w").close() self.assertTrue(MockTarget.fs.exists("/foo/bar.txt")) def test_default_target(self): """Verify default local target url""" target = OpenerTarget(self.local_file) self.assertEqual(type(target), LocalTarget) # Write to the target target.open("w").close() self.assertTrue(LocalTarget.fs.exists(self.local_file)) def test_local_target(self): """Verify basic local target url""" local_file = "file://{}".format(self.local_file) target = OpenerTarget(local_file) self.assertEqual(type(target), LocalTarget) # Write to the target target.open("w").close() self.assertTrue(LocalTarget.fs.exists(self.local_file)) @mock.patch("luigi.local_target.LocalTarget.__init__") @mock.patch("luigi.local_target.LocalTarget.__del__") def test_local_tmp_target(self, lt_del_patch, lt_init_patch): """Verify local target url with query string""" lt_init_patch.return_value = None lt_del_patch.return_value = None local_file = "file://{}?is_tmp".format(self.local_file) OpenerTarget(local_file) lt_init_patch.assert_called_with(self.local_file, is_tmp=True) @mock.patch("luigi.contrib.s3.S3Target.__init__") def test_s3_parse(self, s3_init_patch): """Verify basic s3 target url""" s3_init_patch.return_value = None local_file = "s3://zefr/foo/bar.txt" 
OpenerTarget(local_file) s3_init_patch.assert_called_with("s3://zefr/foo/bar.txt") @mock.patch("luigi.contrib.s3.S3Target.__init__") def test_s3_parse_param(self, s3_init_patch): """Verify s3 target url with params""" s3_init_patch.return_value = None local_file = "s3://zefr/foo/bar.txt?foo=hello&bar=true" OpenerTarget(local_file) s3_init_patch.assert_called_with("s3://zefr/foo/bar.txt", foo="hello", bar="true") def test_binary_support(self): """ Make sure keyword arguments are preserved through the OpenerTarget """ # Verify we can't normally write binary data fp = OpenerTarget("mock://file.txt").open("w") self.assertRaises(TypeError, fp.write, b"\x07\x08\x07") # Verify the format is passed to the target and write binary data fp = OpenerTarget("mock://file.txt", format=luigi.format.MixedUnicodeBytes).open("w") fp.write(b"\x07\x08\x07") fp.close() ================================================ FILE: test/contrib/pai_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2017 Open Targets # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Tests for OpenPAI wrapper for Luigi. Written and maintained by Liu, Dongqing (@liudongqing). 
""" import logging import time import responses from helpers import unittest import luigi from luigi.contrib.pai import PaiTask, TaskRole logging.basicConfig(level=logging.DEBUG) """ The following configurations are required to run the test [OpenPai] pai_url:http://host:port/ username:admin password:admin-password expiration:3600 """ class SklearnJob(PaiTask): image = "openpai/pai.example.sklearn" name = "test_job_sk_{0}".format(time.time()) command = "cd scikit-learn/benchmarks && python bench_mnist.py" virtual_cluster = "spark" tasks = [TaskRole("test", "cd scikit-learn/benchmarks && python bench_mnist.py", memoryMB=4096)] class TestPaiTask(unittest.TestCase): @responses.activate def test_success(self): """ Here using the responses lib to mock the PAI rest api call, the following specify the response of the call. """ responses.add(responses.POST, "http://127.0.0.1:9186/api/v1/token", json={"token": "test", "user": "admin", "admin": True}, status=200) sk_task = SklearnJob() responses.add(responses.POST, "http://127.0.0.1:9186/api/v1/jobs", json={"message": "update job {0} successfully".format(sk_task.name)}, status=202) responses.add(responses.GET, "http://127.0.0.1:9186/api/v1/jobs/{0}".format(sk_task.name), json={}, status=404) responses.add(responses.GET, "http://127.0.0.1:9186/api/v1/jobs/{0}".format(sk_task.name), body='{"jobStatus": {"state":"SUCCEED"}}', status=200) success = luigi.build([sk_task], local_scheduler=True) self.assertTrue(success) self.assertTrue(sk_task.complete()) @responses.activate def test_fail(self): """ Here using the responses lib to mock the PAI rest api call, the following specify the response of the call. 
""" responses.add(responses.POST, "http://127.0.0.1:9186/api/v1/token", json={"token": "test", "user": "admin", "admin": True}, status=200) fail_task = SklearnJob() responses.add(responses.POST, "http://127.0.0.1:9186/api/v1/jobs", json={"message": "update job {0} successfully".format(fail_task.name)}, status=202) responses.add(responses.GET, "http://127.0.0.1:9186/api/v1/jobs/{0}".format(fail_task.name), json={}, status=404) responses.add(responses.GET, "http://127.0.0.1:9186/api/v1/jobs/{0}".format(fail_task.name), body='{"jobStatus": {"state":"FAILED"}}', status=200) success = luigi.build([fail_task], local_scheduler=True) self.assertFalse(success) self.assertFalse(fail_task.complete()) ================================================ FILE: test/contrib/pig_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import subprocess import tempfile import pytest from helpers import unittest from mock import patch import luigi from luigi.contrib.pig import PigJobError, PigJobTask class SimpleTestJob(PigJobTask): def output(self): return luigi.LocalTarget("simple-output") def pig_script_path(self): return "my_simple_pig_script.pig" class ComplexTestJob(PigJobTask): def output(self): return luigi.LocalTarget("complex-output") def pig_script_path(self): return "my_complex_pig_script.pig" def pig_env_vars(self): return {"PIG_CLASSPATH": "/your/path"} def pig_properties(self): return {"pig.additional.jars": "/path/to/your/jar"} def pig_parameters(self): return {"YOUR_PARAM_NAME": "Your param value"} def pig_options(self): return ["-x", "local"] @pytest.mark.apache class SimplePigTest(unittest.TestCase): def setUp(self): pass def tearDown(self): pass @patch("subprocess.Popen") def test_run__success(self, mock): arglist_result = [] p = subprocess.Popen subprocess.Popen = _get_fake_Popen(arglist_result, 0) try: job = SimpleTestJob() job.run() self.assertEqual([["/usr/share/pig/bin/pig", "-f", "my_simple_pig_script.pig"]], arglist_result) finally: subprocess.Popen = p @patch("subprocess.Popen") def test_run__fail(self, mock): arglist_result = [] p = subprocess.Popen subprocess.Popen = _get_fake_Popen(arglist_result, 1) try: job = SimpleTestJob() job.run() self.assertEqual([["/usr/share/pig/bin/pig", "-f", "my_simple_pig_script.pig"]], arglist_result) except PigJobError as e: p = e self.assertEqual("stderr", p.err) else: self.fail("Should have thrown PigJobError") finally: subprocess.Popen = p @pytest.mark.apache class ComplexPigTest(unittest.TestCase): def setUp(self): pass def tearDown(self): pass @patch("subprocess.Popen") def test_run__success(self, mock): arglist_result = [] p = subprocess.Popen subprocess.Popen = _get_fake_Popen(arglist_result, 0) with ( tempfile.NamedTemporaryFile(delete=False) as param_file_mock, tempfile.NamedTemporaryFile(delete=False) as prop_file_mock, 
patch("luigi.contrib.pig.tempfile.NamedTemporaryFile", side_effect=[param_file_mock, prop_file_mock]), ): try: job = ComplexTestJob() job.run() self.assertEqual( [ [ "/usr/share/pig/bin/pig", "-x", "local", "-param_file", param_file_mock.name, "-propertyFile", prop_file_mock.name, "-f", "my_complex_pig_script.pig", ] ], arglist_result, ) # Check param file with open(param_file_mock.name) as pparams_file: pparams = pparams_file.readlines() self.assertEqual(1, len(pparams)) self.assertEqual("YOUR_PARAM_NAME=Your param value\n", pparams[0]) # Check property file with open(prop_file_mock.name) as pprops_file: pprops = pprops_file.readlines() self.assertEqual(1, len(pprops)) self.assertEqual("pig.additional.jars=/path/to/your/jar\n", pprops[0]) finally: subprocess.Popen = p @patch("subprocess.Popen") def test_run__fail(self, mock): arglist_result = [] p = subprocess.Popen subprocess.Popen = _get_fake_Popen(arglist_result, 1) with ( tempfile.NamedTemporaryFile(delete=False) as param_file_mock, tempfile.NamedTemporaryFile(delete=False) as prop_file_mock, patch("luigi.contrib.pig.tempfile.NamedTemporaryFile", side_effect=[param_file_mock, prop_file_mock]), ): try: job = ComplexTestJob() job.run() except PigJobError as e: p = e self.assertEqual("stderr", p.err) self.assertEqual( [ [ "/usr/share/pig/bin/pig", "-x", "local", "-param_file", param_file_mock.name, "-propertyFile", prop_file_mock.name, "-f", "my_complex_pig_script.pig", ] ], arglist_result, ) # Check param file with open(param_file_mock.name) as pparams_file: pparams = pparams_file.readlines() self.assertEqual(1, len(pparams)) self.assertEqual("YOUR_PARAM_NAME=Your param value\n", pparams[0]) # Check property file with open(prop_file_mock.name) as pprops_file: pprops = pprops_file.readlines() self.assertEqual(1, len(pprops)) self.assertEqual("pig.additional.jars=/path/to/your/jar\n", pprops[0]) else: self.fail("Should have thrown PigJobError") finally: subprocess.Popen = p def _get_fake_Popen(arglist_result, 
return_code, *args, **kwargs): def Popen_fake(arglist, shell=None, stdout=None, stderr=None, env=None, close_fds=True): arglist_result.append(arglist) class P: number_of_process_polls = 5 def __init__(self): self._process_polls_left = self.number_of_process_polls def wait(self): pass def poll(self): if self._process_polls_left: self._process_polls_left -= 1 return None return 0 def communicate(self): return "end" def env(self): return self.env p = P() p.returncode = return_code p.stderr = tempfile.TemporaryFile() p.stdout = tempfile.TemporaryFile() p.stdout.write(b"stdout") p.stderr.write(b"stderr") # Reset temp files so the output can be read. p.stdout.seek(0) p.stderr.seek(0) return p return Popen_fake ================================================ FILE: test/contrib/postgres_test.py ================================================ # Copyright (c) 2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of # the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations under # the License. import datetime import mock import pytest from helpers import unittest import luigi import luigi.contrib.postgres from luigi.tools.range import RangeDaily def datetime_to_epoch(dt): td = dt - datetime.datetime(1970, 1, 1) return td.days * 86400 + td.seconds + td.microseconds / 1e6 class MockPostgresCursor(mock.Mock): """ Keeps state to simulate executing SELECT queries and fetching results. 
""" def __init__(self, existing_update_ids): super(MockPostgresCursor, self).__init__() self.existing = existing_update_ids def execute(self, query, params): if query.startswith("SELECT 1 FROM table_updates"): self.fetchone_result = (1,) if params[0] in self.existing else None else: self.fetchone_result = None def fetchone(self): return self.fetchone_result class DummyPostgresImporter(luigi.contrib.postgres.CopyToTable): date = luigi.DateParameter() host = "dummy_host" database = "dummy_database" user = "dummy_user" password = "dummy_password" table = "dummy_table" columns = ( ("some_text", "text"), ("some_int", "int"), ) @pytest.mark.postgres class DailyCopyToTableTest(unittest.TestCase): maxDiff = None @mock.patch("psycopg2.connect") def test_bulk_complete(self, mock_connect): mock_cursor = MockPostgresCursor([DummyPostgresImporter(date=datetime.datetime(2015, 1, 3)).task_id]) mock_connect.return_value.cursor.return_value = mock_cursor task = RangeDaily(of=DummyPostgresImporter, start=datetime.date(2015, 1, 2), now=datetime_to_epoch(datetime.datetime(2015, 1, 7))) actual = sorted([t.task_id for t in task.requires()]) self.assertEqual( actual, sorted( [ DummyPostgresImporter(date=datetime.datetime(2015, 1, 2)).task_id, DummyPostgresImporter(date=datetime.datetime(2015, 1, 4)).task_id, DummyPostgresImporter(date=datetime.datetime(2015, 1, 5)).task_id, DummyPostgresImporter(date=datetime.datetime(2015, 1, 6)).task_id, ] ), ) self.assertFalse(task.complete()) class DummyPostgresQuery(luigi.contrib.postgres.PostgresQuery): date = luigi.DateParameter() host = "dummy_host" database = "dummy_database" user = "dummy_user" password = "dummy_password" table = "dummy_table" columns = ( ("some_text", "text"), ("some_int", "int"), ) query = "SELECT * FROM foo" class DummyPostgresQueryWithPort(DummyPostgresQuery): port = 1234 class DummyPostgresQueryWithPortEncodedInHost(DummyPostgresQuery): host = "dummy_host:1234" @pytest.mark.postgres class 
PostgresQueryTest(unittest.TestCase): maxDiff = None @mock.patch("psycopg2.connect") def test_bulk_complete(self, mock_connect): mock_cursor = MockPostgresCursor(["DummyPostgresQuery_2015_01_03_838e32a989"]) mock_connect.return_value.cursor.return_value = mock_cursor task = RangeDaily(of=DummyPostgresQuery, start=datetime.date(2015, 1, 2), now=datetime_to_epoch(datetime.datetime(2015, 1, 7))) actual = [t.task_id for t in task.requires()] self.assertEqual( actual, [ "DummyPostgresQuery_2015_01_02_3a0ec498ed", "DummyPostgresQuery_2015_01_04_9c1d42ff62", "DummyPostgresQuery_2015_01_05_0f90e52357", "DummyPostgresQuery_2015_01_06_f91a47ec40", ], ) self.assertFalse(task.complete()) def test_override_port(self): output = DummyPostgresQueryWithPort(date=datetime.datetime(1991, 3, 24)).output() self.assertEqual(output.port, 1234) def test_port_encoded_in_host(self): output = DummyPostgresQueryWithPortEncodedInHost(date=datetime.datetime(1991, 3, 24)).output() self.assertEqual(output.port, "1234") @pytest.mark.postgres class TestCopyToTableWithMetaColumns(unittest.TestCase): @mock.patch("luigi.contrib.postgres.CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) @mock.patch("luigi.contrib.postgres.CopyToTable._add_metadata_columns") @mock.patch("luigi.contrib.postgres.CopyToTable.post_copy_metacolumns") @mock.patch("luigi.contrib.postgres.CopyToTable.rows", return_value=["row1", "row2"]) @mock.patch("luigi.contrib.postgres.PostgresTarget") @mock.patch("psycopg2.connect") def test_copy_with_metadata_columns_enabled( self, mock_connect, mock_redshift_target, mock_rows, mock_add_columns, mock_update_columns, mock_metadata_columns_enabled ): task = DummyPostgresImporter(date=datetime.datetime(1991, 3, 24)) mock_cursor = MockPostgresCursor([task.task_id]) mock_connect.return_value.cursor.return_value = mock_cursor task = DummyPostgresImporter(date=datetime.datetime(1991, 3, 24)) task.run() self.assertTrue(mock_add_columns.called) 
self.assertTrue(mock_update_columns.called) @mock.patch("luigi.contrib.postgres.CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=False) @mock.patch("luigi.contrib.postgres.CopyToTable._add_metadata_columns") @mock.patch("luigi.contrib.postgres.CopyToTable.post_copy_metacolumns") @mock.patch("luigi.contrib.postgres.CopyToTable.rows", return_value=["row1", "row2"]) @mock.patch("luigi.contrib.postgres.PostgresTarget") @mock.patch("psycopg2.connect") def test_copy_with_metadata_columns_disabled( self, mock_connect, mock_redshift_target, mock_rows, mock_add_columns, mock_update_columns, mock_metadata_columns_enabled ): task = DummyPostgresImporter(date=datetime.datetime(1991, 3, 24)) mock_cursor = MockPostgresCursor([task.task_id]) mock_connect.return_value.cursor.return_value = mock_cursor task.run() self.assertFalse(mock_add_columns.called) self.assertFalse(mock_update_columns.called) ================================================ FILE: test/contrib/postgres_with_server_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import os import pytest from helpers import unittest import luigi import luigi.notifications from luigi.contrib import postgres """ Typical use cases that should be tested: * Daily overwrite of all data in table * Daily inserts of new segment in table * (Daily insertion/creation of new table) * Daily insertion of multiple (different) new segments into table """ host = "localhost" database = "spotify" user = os.getenv("POSTGRES_USER", "spotify") password = "guest" try: import psycopg2 conn = psycopg2.connect( user=user, host=host, database=database, password=password, ) conn.close() psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY) except Exception: raise unittest.SkipTest("Unable to connect to postgres") # to avoid copying: class CopyToTestDB(postgres.CopyToTable): host = host database = database user = user password = password class TestPostgresTask(CopyToTestDB): table = "test_table" columns = (("test_text", "text"), ("test_int", "int"), ("test_float", "float")) def create_table(self, connection): connection.cursor().execute("CREATE TABLE {table} (id SERIAL PRIMARY KEY, test_text TEXT, test_int INT, test_float FLOAT)".format(table=self.table)) def rows(self): yield "foo", 123, 123.45 yield None, "-100", "5143.213" yield "\t\n\r\\N", 0, 0 yield "éцү我", 0, 0 yield "", 0, r"\N" # Test working default null charcter class MetricBase(CopyToTestDB): table = "metrics" columns = [("metric", "text"), ("value", "int")] class Metric1(MetricBase): param = luigi.Parameter() def rows(self): yield "metric1", 1 yield "metric1", 2 yield "metric1", 3 class Metric2(MetricBase): param = luigi.Parameter() def rows(self): yield "metric2", 1 yield "metric2", 4 yield "metric2", 3 @pytest.mark.postgres class TestPostgresImportTask(unittest.TestCase): def test_default_escape(self): self.assertEqual(postgres.default_escape("foo"), "foo") self.assertEqual(postgres.default_escape("\n"), "\\n") 
self.assertEqual(postgres.default_escape("\\\n"), "\\\\\\n") self.assertEqual(postgres.default_escape("\n\r\\\t\\N\\"), "\\n\\r\\\\\\t\\\\N\\\\") def test_repeat(self): task = TestPostgresTask() conn = task.output().connect() conn.autocommit = True cursor = conn.cursor() cursor.execute("DROP TABLE IF EXISTS {table}".format(table=task.table)) cursor.execute("DROP TABLE IF EXISTS {marker_table}".format(marker_table=postgres.PostgresTarget.marker_table)) luigi.build([task], local_scheduler=True) luigi.build([task], local_scheduler=True) # try to schedule twice cursor.execute("""SELECT test_text, test_int, test_float FROM test_table ORDER BY id ASC""") rows = tuple(cursor) self.assertEqual( rows, ( ("foo", 123, 123.45), (None, -100, 5143.213), ("\t\n\r\\N", 0.0, 0), ("éцү我", 0, 0), ("", 0, None), # Test working default null charcter ), ) def test_multimetric(self): metrics = MetricBase() conn = metrics.output().connect() conn.autocommit = True conn.cursor().execute("DROP TABLE IF EXISTS {table}".format(table=metrics.table)) conn.cursor().execute("DROP TABLE IF EXISTS {marker_table}".format(marker_table=postgres.PostgresTarget.marker_table)) luigi.build([Metric1(20), Metric1(21), Metric2("foo")], local_scheduler=True) cursor = conn.cursor() cursor.execute("select count(*) from {table}".format(table=metrics.table)) self.assertEqual(tuple(cursor), ((9,),)) def test_clear(self): class Metric2Copy(Metric2): def init_copy(self, connection): query = "TRUNCATE {0}".format(self.table) connection.cursor().execute(query) clearer = Metric2Copy(21) conn = clearer.output().connect() conn.autocommit = True conn.cursor().execute("DROP TABLE IF EXISTS {table}".format(table=clearer.table)) conn.cursor().execute("DROP TABLE IF EXISTS {marker_table}".format(marker_table=postgres.PostgresTarget.marker_table)) luigi.build([Metric1(0), Metric1(1)], local_scheduler=True) luigi.build([clearer], local_scheduler=True) cursor = conn.cursor() cursor.execute("select count(*) from 
{table}".format(table=clearer.table)) self.assertEqual(tuple(cursor), ((3,),)) ================================================ FILE: test/contrib/presto_test.py ================================================ import unittest import mock from pyhive.exc import DatabaseError from pyhive.presto import Connection, Cursor from luigi.contrib.presto import PrestoClient, PrestoTarget, PrestoTask class WithPrestoClientTest(unittest.TestCase): def test_creates_client_with_expected_params(self): # arrange class _Task(PrestoTask): host = "127.0.0.1" port = 8089 user = "user_123" database = "db1" table = "tbl1" expected_connection_kwargs = { "host": "127.0.0.1", "port": 8089, "username": "user_123", "catalog": "hive", "protocol": "https", "source": "pyhive", "poll_interval": 1.0, "schema": "db1", "requests_kwargs": {"verify": False}, } # act task = _Task() # assert client = task._client assert isinstance(client, PrestoClient) connection = client._connection assert not connection._args assert connection._kwargs == expected_connection_kwargs class PrestoClientTest(unittest.TestCase): @mock.patch("luigi.contrib.presto.sleep", return_value=None) def test_watch(self, sleep): # arrange status = {"stats": {"progressPercentage": 1.2}, "infoUri": "http://127.0.0.1:8080/ui/query.html?query=123"} cursor = mock.MagicMock(spec=Cursor) cursor.poll.side_effect = [status, None] connection = mock.MagicMock(spec=Connection) connection.cursor.return_value = cursor client = PrestoClient(connection) query = "select 1" # act statuses = list(client.execute(query)) # assert assert client.percentage_progress == 1.2 assert client.info_uri == "http://127.0.0.1:8080/ui/query.html?query=123" assert statuses == [status] cursor.execute.assert_called_once_with(query, None) cursor.close.assert_called_once_with() @mock.patch("luigi.contrib.presto.sleep", return_value=None) def test_fetch(self, sleep): # arrange status = {"infoUri": "http://127.0.0.1:8080/ui/query.html?query=123"} cursor = 
mock.MagicMock(spec=Cursor) cursor.poll.side_effect = [status, None] cursor.fetchall.return_value = [(1,), (2,)] connection = mock.MagicMock(spec=Connection) connection.cursor.return_value = cursor client = PrestoClient(connection) query = "select 1" # act result = list(client.execute(query, mode="fetch")) # assert assert client.percentage_progress == 0.1 assert client.info_uri == "http://127.0.0.1:8080/ui/query.html?query=123" cursor.execute.assert_called_once_with(query, None) cursor.close.assert_called_once_with() assert result == [(1,), (2,)] class PrestoTargetTest(unittest.TestCase): def test_non_partitioned(self): # arrange client = mock.MagicMock(spec=PrestoClient) client.execute.return_value = iter( [ (7, None), ] ) catalog = "hive" database = "schm1" table = "tbl1" # act target = PrestoTarget(client, catalog, database, table) count = target.count() exists = target.exists() # assert client.execute.assert_called_once_with( "SELECT COUNT(*) AS cnt FROM hive.schm1.tbl1 WHERE 1 = %s LIMIT 1", [ 1, ], mode="fetch", ) assert count == 7 assert exists def test_partitioned(self): # arrange client = mock.MagicMock(spec=PrestoClient) client.execute.return_value = iter( [ (7, None), ] ) catalog = "hive" database = "schm1" table = "tbl1" partition = {"a": 2, "b": "x"} # act target = PrestoTarget(client, catalog, database, table, partition) count = target.count() exists = target.exists() # assert client.execute.assert_called_once_with("SELECT COUNT(*) AS cnt FROM hive.schm1.tbl1 WHERE a = %s AND b = %s LIMIT 1", [2, "x"], mode="fetch") assert count == 7 assert exists def test_table_doesnot_exist(self): # arrange e = DatabaseError() setattr(e, "message", {"message": "line 1:15: Table hive.schm1.tbl1 does not exist"}) client = mock.MagicMock(spec=PrestoClient) client.execute.side_effect = e catalog = "hive" database = "schm1" table = "tbl1" # act target = PrestoTarget(client, catalog, database, table) exists = target.exists() # assert 
client.execute.assert_called_once_with("SELECT COUNT(*) AS cnt FROM hive.schm1.tbl1 WHERE 1 = %s LIMIT 1", [1], mode="fetch") assert not exists class PrestoTest(unittest.TestCase): @mock.patch("luigi.contrib.presto.sleep", return_value=None) def test_run(self, sleep): # arrange client = mock.MagicMock(spec=PrestoClient) client.execute.return_value = [(), (), ()] client.info_uri = "http://127.0.0.1:8080/ui/query.html?query=123" client.percentage_progress = 2.3 class _Task(PrestoTask): host = "127.0.0.1" port = 8089 user = "user_123" password = "123" database = "db1" table = "tbl1" query = "select 1" # act with mock.patch("luigi.contrib.presto.PrestoClient", return_value=client): task = _Task() task.set_progress_percentage = mock.MagicMock() task.set_tracking_url = mock.MagicMock() task.run() # assert assert task.protocol == "https" assert task.output().catalog == "hive" assert task.output().database == "db1" assert task.output().table == "tbl1" assert task.output().partition is None client.execute.assert_called_once_with("select 1") task.set_tracking_url.assert_called_once_with("http://127.0.0.1:8080/ui/query.html?query=123") assert task.set_progress_percentage.mock_calls == [ mock.call(2.3), mock.call(2.3), mock.call(2.3), ] ================================================ FILE: test/contrib/prometheus_metric_test.py ================================================ import pytest from helpers import unittest from prometheus_client import CONTENT_TYPE_LATEST from luigi.contrib.prometheus_metric import PrometheusMetricsCollector from luigi.metrics import MetricsCollectors from luigi.scheduler import Scheduler try: from unittest import mock except ImportError: import mock WORKER = "myworker" TASK_ID = "TaskID" TASK_FAMILY = "TaskFamily" A_PARAM_VALUE = "1" B_PARAM_VALUE = "2" C_PARAM_VALUE = "3" @pytest.mark.contrib class PrometheusMetricBaseTest(unittest.TestCase): COLLECTOR_KWARGS = {} EXPECTED_LABELS = {"family": TASK_FAMILY} def setUp(self): self.collector = 
PrometheusMetricsCollector(**self.COLLECTOR_KWARGS) self.s = Scheduler(metrics_collector=MetricsCollectors.prometheus) self.gauge_name = "luigi_task_execution_time_seconds" def startTask(self): self.s.add_task( worker=WORKER, task_id=TASK_ID, family=TASK_FAMILY, params={"a": A_PARAM_VALUE, "b": B_PARAM_VALUE, "c": C_PARAM_VALUE}, ) task = self.s._state.get_task(TASK_ID) task.time_running = 0 task.updated = 5 return task def test_handle_task_started(self): task = self.startTask() self.collector.handle_task_started(task) counter_name = "luigi_task_started_total" gauge_name = self.gauge_name labels = self.EXPECTED_LABELS assert self.collector.registry.get_sample_value(counter_name, labels=labels) == 1 assert self.collector.registry.get_sample_value(gauge_name, labels=labels) == 0 def test_handle_task_failed(self): task = self.startTask() self.collector.handle_task_failed(task) counter_name = "luigi_task_failed_total" gauge_name = self.gauge_name labels = self.EXPECTED_LABELS assert self.collector.registry.get_sample_value(counter_name, labels=labels) == 1 assert self.collector.registry.get_sample_value(gauge_name, labels=labels) == task.updated - task.time_running def test_handle_task_disabled(self): task = self.startTask() self.collector.handle_task_disabled(task, self.s._config) counter_name = "luigi_task_disabled_total" gauge_name = self.gauge_name labels = self.EXPECTED_LABELS assert self.collector.registry.get_sample_value(counter_name, labels=labels) == 1 assert self.collector.registry.get_sample_value(gauge_name, labels=labels) == task.updated - task.time_running def test_handle_task_done(self): task = self.startTask() self.collector.handle_task_done(task) counter_name = "luigi_task_done_total" gauge_name = self.gauge_name labels = self.EXPECTED_LABELS assert self.collector.registry.get_sample_value(counter_name, labels=labels) == 1 assert self.collector.registry.get_sample_value(gauge_name, labels=labels) == task.updated - task.time_running def 
test_configure_http_handler(self): mock_http_handler = mock.MagicMock() self.collector.configure_http_handler(mock_http_handler) mock_http_handler.set_header.assert_called_once_with("Content-Type", CONTENT_TYPE_LATEST) @pytest.mark.contrib class PrometheusMetricTaskParamsOnlyTest(PrometheusMetricBaseTest): COLLECTOR_KWARGS = { "use_task_family_in_labels": False, "task_parameters_to_use_in_labels": ["a", "c"], } EXPECTED_LABELS = {"a": A_PARAM_VALUE, "c": C_PARAM_VALUE} @pytest.mark.contrib class PrometheusMetricTaskFamilyAndTaskParamsTest(PrometheusMetricBaseTest): COLLECTOR_KWARGS = { "use_task_family_in_labels": True, "task_parameters_to_use_in_labels": ["b"], } EXPECTED_LABELS = {"family": TASK_FAMILY, "b": B_PARAM_VALUE} ================================================ FILE: test/contrib/rdbms_test.py ================================================ # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ We're using Redshift as the test bed since Redshift implements RDBMS. We could have opted for PSQL but we're less familiar with that contrib and there are less examples on how to test it. """ import unittest import mock import pytest import luigi import luigi.contrib.redshift # Fake AWS and S3 credentials taken from `../redshift_test.py`. 
AWS_ACCESS_KEY = "key" AWS_SECRET_KEY = "secret" AWS_ACCOUNT_ID = "0123456789012" AWS_ROLE_NAME = "MyRedshiftRole" BUCKET = "bucket" KEY = "key" class DummyS3CopyToTableBase(luigi.contrib.redshift.S3CopyToTable): # Class attributes taken from `DummyPostgresImporter` in # `../postgres_test.py`. host = "dummy_host" database = "dummy_database" user = "dummy_user" password = "dummy_password" table = luigi.Parameter(default="dummy_table") columns = luigi.TupleParameter( default=( ("some_text", "varchar(255)"), ("some_int", "int"), ) ) copy_options = "" prune_table = "" prune_column = "" prune_date = "" def s3_load_path(self): return "s3://%s/%s" % (BUCKET, KEY) class DummyS3CopyToTableKey(DummyS3CopyToTableBase): aws_access_key_id = AWS_ACCESS_KEY aws_secret_access_key = AWS_SECRET_KEY @pytest.mark.aws class TestS3CopyToTableWithMetaColumns(unittest.TestCase): @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[("created_tz", "TIMESTAMP")]) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_copy_check_meta_columns_to_table_if_exists(self, mock_redshift_target, mock_metadata_columns, mock_metadata_columns_enabled): task = DummyS3CopyToTableKey(table="my_test_table") task.run() mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value executed_query = mock_cursor.execute.call_args_list[1][0][0] expected_output = ( "SELECT 1 AS column_exists FROM information_schema.columns WHERE table_name = LOWER('{table}') AND column_name = LOWER('{column}') LIMIT 1;".format( table="my_test_table", column="created_tz" ) ) self.assertEqual(executed_query, expected_output) @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) 
@mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[("created_tz", "TIMESTAMP")]) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_copy_check_meta_columns_to_schematable_if_exists(self, mock_redshift_target, mock_metadata_columns, mock_metadata_columns_enabled): task = DummyS3CopyToTableKey(table="test.my_test_table") task.run() mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value executed_query = mock_cursor.execute.call_args_list[2][0][0] expected_output = ( "SELECT 1 AS column_exists FROM information_schema.columns " "WHERE table_schema = LOWER('{schema}') " "AND table_name = LOWER('{table}') " "AND column_name = LOWER('{column}') " "LIMIT 1;".format(schema="test", table="my_test_table", column="created_tz") ) self.assertEqual(executed_query, expected_output) @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[("created_tz", "TIMESTAMP")]) @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=True) @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_column_to_table") @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_copy_not_add_if_meta_columns_already_exists( self, mock_redshift_target, mock_add_to_table, mock_columns_exists, mock_metadata_columns, mock_metadata_columns_enabled ): task = DummyS3CopyToTableKey() task.run() self.assertFalse(mock_add_to_table.called) @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[("created_tz", "TIMESTAMP")]) @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) 
@mock.patch("luigi.contrib.redshift.S3CopyToTable._add_column_to_table") @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_copy_add_if_meta_columns_not_already_exists( self, mock_redshift_target, mock_add_to_table, mock_columns_exists, mock_metadata_columns, mock_metadata_columns_enabled ): task = DummyS3CopyToTableKey() task.run() self.assertTrue(mock_add_to_table.called) @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[("created_tz", "TIMESTAMP")]) @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_copy_add_regular_column(self, mock_redshift_target, mock_columns_exists, mock_metadata_columns, mock_metadata_columns_enabled): task = DummyS3CopyToTableKey(table="my_test_table") task.run() mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value executed_query = mock_cursor.execute.call_args_list[1][0][0] expected_output = "ALTER TABLE {table} ADD COLUMN {column} {type};".format(table="my_test_table", column="created_tz", type="TIMESTAMP") self.assertEqual(executed_query, expected_output) @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[("created_tz", "TIMESTAMP", "bytedict")]) @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_copy_add_encoded_column(self, mock_redshift_target, mock_columns_exists, mock_metadata_columns, mock_metadata_columns_enabled): task = DummyS3CopyToTableKey(table="my_test_table") task.run() mock_cursor = 
mock_redshift_target.return_value.connect.return_value.cursor.return_value executed_query = mock_cursor.execute.call_args_list[1][0][0] expected_output = "ALTER TABLE {table} ADD COLUMN {column} {type} ENCODE {encoding};".format( table="my_test_table", column="created_tz", type="TIMESTAMP", encoding="bytedict" ) self.assertEqual(executed_query, expected_output) @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[("created_tz")]) @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_copy_raise_error_on_no_column_type(self, mock_redshift_target, mock_columns_exists, mock_metadata_columns, mock_metadata_columns_enabled): task = DummyS3CopyToTableKey() with self.assertRaises(ValueError): task.run() @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) @mock.patch( "luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[("created_tz", "TIMESTAMP", "bytedict", "42")] ) @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_copy_raise_error_on_invalid_column(self, mock_redshift_target, mock_columns_exists, mock_metadata_columns, mock_metadata_columns_enabled): task = DummyS3CopyToTableKey() with self.assertRaises(ValueError): task.run() @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_queries", new_callable=mock.PropertyMock, return_value=["SELECT 1 FROM X", "SELECT 2 FROM Y"]) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def 
test_post_copy_metacolumns(self, mock_redshift_target, mock_metadata_queries, mock_metadata_columns_enabled): task = DummyS3CopyToTableKey() task.run() mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value executed_query = mock_cursor.execute.call_args_list[2][0][0] expected_output = "SELECT 1 FROM X" self.assertEqual(executed_query, expected_output) executed_query = mock_cursor.execute.call_args_list[3][0][0] expected_output = "SELECT 2 FROM Y" self.assertEqual(executed_query, expected_output) ================================================ FILE: test/contrib/redis_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# # pylint: disable=F0401 from time import sleep import pytest from helpers import unittest try: import redis except ImportError: raise unittest.SkipTest("Unable to load redis module") from luigi.contrib.redis_store import RedisTarget HOST = "localhost" PORT = 6379 DB = 15 PASSWORD = None SOCKET_TIMEOUT = None MARKER_PREFIX = "luigi_test" EXPIRE = 5 @pytest.mark.contrib class RedisTargetTest(unittest.TestCase): """Test touch, exists and target expiration""" def test_touch_and_exists(self): target = RedisTarget(HOST, PORT, DB, "update_id", PASSWORD) target.marker_prefix = MARKER_PREFIX flush() self.assertFalse(target.exists(), "Target should not exist before touching it") target.touch() self.assertTrue(target.exists(), "Target should exist after touching it") flush() def test_expiration(self): target = RedisTarget(HOST, PORT, DB, "update_id", PASSWORD, None, EXPIRE) target.marker_prefix = MARKER_PREFIX flush() target.touch() self.assertTrue(target.exists(), "Target should exist after touching it and before expiring") sleep(EXPIRE) self.assertFalse(target.exists(), "Target should not exist after expiring") flush() def flush(): """Flush test DB""" redis_client = redis.StrictRedis(host=HOST, port=PORT, db=DB, socket_timeout=SOCKET_TIMEOUT) redis_client.flushdb() ================================================ FILE: test/contrib/redshift_test.py ================================================ # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import json import os import sys import mock import pytest from helpers import unittest, with_config from moto import mock_s3 import luigi import luigi.contrib.redshift import luigi.notifications from luigi.contrib import redshift from luigi.contrib.s3 import S3Client if (3, 4, 0) <= sys.version_info[:3] < (3, 4, 3): # spulec/moto#308 mock_s3 = unittest.skip("moto mock doesn't work with python3.4") # NOQA # Fake AWS and S3 credentials taken from `../redshift_test.py`. AWS_ACCESS_KEY = "key" AWS_SECRET_KEY = "secret" AWS_ACCOUNT_ID = "0123456789012" AWS_ROLE_NAME = "MyRedshiftRole" BUCKET = "bucket" KEY = "key" KEY_2 = "key2" FILES = ["file1", "file2", "file3"] def generate_manifest_json(path_to_folders, file_names): entries = [] for path_to_folder in path_to_folders: for file_name in file_names: entries.append({"url": "%s/%s" % (path_to_folder, file_name), "mandatory": True}) return {"entries": entries} class DummyS3CopyToTableBase(luigi.contrib.redshift.S3CopyToTable): # Class attributes taken from `DummyPostgresImporter` in # `../postgres_test.py`. host = "dummy_host" database = "dummy_database" user = "dummy_user" password = "dummy_password" table = luigi.Parameter(default="dummy_table") columns = luigi.TupleParameter( default=( ("some_text", "varchar(255)"), ("some_int", "int"), ) ) table_constraints = luigi.Parameter(default="") copy_options = "" prune_table = "" prune_column = "" prune_date = "" def s3_load_path(self): return "s3://%s/%s" % (BUCKET, KEY) class DummyS3CopyJSONToTableBase(luigi.contrib.redshift.S3CopyJSONToTable): # Class attributes taken from `DummyPostgresImporter` in # `../postgres_test.py`. 
aws_access_key_id = AWS_ACCESS_KEY aws_secret_access_key = AWS_SECRET_KEY host = "dummy_host" database = "dummy_database" user = "dummy_user" password = "dummy_password" table = luigi.Parameter(default="dummy_table") columns = luigi.TupleParameter( default=( ("some_text", "varchar(255)"), ("some_int", "int"), ) ) copy_options = "" prune_table = "" prune_column = "" prune_date = "" jsonpath = "" copy_json_options = "" def s3_load_path(self): return "s3://%s/%s" % (BUCKET, KEY) class DummyS3CopyToTableKey(DummyS3CopyToTableBase): aws_access_key_id = AWS_ACCESS_KEY aws_secret_access_key = AWS_SECRET_KEY class DummyS3CopyToTableWithCompressionEncodings(DummyS3CopyToTableKey): columns = ( ("some_text", "varchar(255)", "LZO"), ("some_int", "int", "DELTA"), ) class DummyS3CopyToTableRole(DummyS3CopyToTableBase): aws_account_id = AWS_ACCESS_KEY aws_arn_role_name = AWS_SECRET_KEY class DummyS3CopyToTempTable(DummyS3CopyToTableKey): # Extend/alter DummyS3CopyToTable for temp table copying table = luigi.Parameter(default="stage_dummy_table") table_type = "TEMP" prune_date = "current_date - 30" prune_column = "dumb_date" prune_table = "stage_dummy_table" queries = ["insert into dummy_table select * from stage_dummy_table;"] @pytest.mark.aws class TestInternalCredentials(unittest.TestCase, DummyS3CopyToTableKey): def test_from_property(self): self.assertEqual(self.aws_access_key_id, AWS_ACCESS_KEY) self.assertEqual(self.aws_secret_access_key, AWS_SECRET_KEY) @pytest.mark.aws class TestExternalCredentials(unittest.TestCase, DummyS3CopyToTableBase): @mock.patch.dict(os.environ, {"AWS_ACCESS_KEY_ID": "env_key", "AWS_SECRET_ACCESS_KEY": "env_secret"}) def test_from_env(self): self.assertEqual(self.aws_access_key_id, "env_key") self.assertEqual(self.aws_secret_access_key, "env_secret") @with_config({"redshift": {"aws_access_key_id": "config_key", "aws_secret_access_key": "config_secret"}}) def test_from_config(self): self.assertEqual(self.aws_access_key_id, "config_key") 
self.assertEqual(self.aws_secret_access_key, "config_secret") @pytest.mark.aws class TestS3CopyToTableWithMetaColumns(unittest.TestCase): @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_metadata_columns") @mock.patch("luigi.contrib.redshift.S3CopyToTable.post_copy_metacolumns") @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_copy_with_metadata_columns_enabled(self, mock_redshift_target, mock_add_columns, mock_update_columns, mock_metadata_columns_enabled): task = DummyS3CopyToTableKey() task.run() self.assertTrue(mock_add_columns.called) self.assertTrue(mock_update_columns.called) @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=False) @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_metadata_columns") @mock.patch("luigi.contrib.redshift.S3CopyToTable.post_copy_metacolumns") @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_copy_with_metadata_columns_disabled(self, mock_redshift_target, mock_add_columns, mock_update_columns, mock_metadata_columns_enabled): task = DummyS3CopyToTableKey() task.run() self.assertFalse(mock_add_columns.called) self.assertFalse(mock_update_columns.called) @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_metadata_columns") @mock.patch("luigi.contrib.redshift.S3CopyToTable.post_copy_metacolumns") @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_json_copy_with_metadata_columns_enabled(self, mock_redshift_target, mock_add_columns, mock_update_columns, mock_metadata_columns_enabled): task = DummyS3CopyJSONToTableBase() task.run() self.assertTrue(mock_add_columns.called) self.assertTrue(mock_update_columns.called) 
@mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=False) @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_metadata_columns") @mock.patch("luigi.contrib.redshift.S3CopyToTable.post_copy_metacolumns") @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_json_copy_with_metadata_columns_disabled(self, mock_redshift_target, mock_add_columns, mock_update_columns, mock_metadata_columns_enabled): task = DummyS3CopyJSONToTableBase() task.run() self.assertFalse(mock_add_columns.called) self.assertFalse(mock_update_columns.called) @pytest.mark.aws class TestS3CopyToTable(unittest.TestCase): @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_copy_missing_creds(self, mock_redshift_target): # Make sure credentials are not set as env vars try: del os.environ["AWS_ACCESS_KEY_ID"] del os.environ["AWS_SECRET_ACCESS_KEY"] except KeyError: pass task = DummyS3CopyToTableBase() # The mocked connection cursor passed to # S3CopyToTable.copy(self, cursor, f). mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value with self.assertRaises(NotImplementedError): task.copy(mock_cursor, task.s3_load_path()) @mock.patch("luigi.contrib.redshift.S3CopyToTable.copy") @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_to_table(self, mock_redshift_target, mock_copy): task = DummyS3CopyToTableKey() task.run() # The mocked connection cursor passed to # S3CopyToTable.copy(self, cursor, f). mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value # `mock_redshift_target` is the mocked `RedshiftTarget` object # returned by S3CopyToTable.output(self). 
mock_redshift_target.assert_called_with( database=task.database, host=task.host, update_id=task.task_id, user=task.user, table=task.table, password=task.password ) # Check if the `S3CopyToTable.s3_load_path` class attribute was # successfully referenced in the `S3CopyToTable.run` method, which is # in-turn passed to `S3CopyToTable.copy` and other functions in `run` # (see issue #995). mock_copy.assert_called_with(mock_cursor, task.s3_load_path()) # Check the SQL query in `S3CopyToTable.does_table_exist`. mock_cursor.execute.assert_called_with("select 1 as table_exists from pg_table_def where tablename = lower(%s) limit 1", (task.table,)) return @mock.patch("luigi.contrib.redshift.S3CopyToTable.does_table_exist", return_value=False) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_to_missing_table(self, mock_redshift_target, mock_does_exist): """ Test missing table creation """ # Ensure `S3CopyToTable.create_table` does not throw an error. task = DummyS3CopyToTableKey() task.run() # Make sure the cursor was successfully used to create the table in # `create_table` as expected. 
mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value assert mock_cursor.execute.call_args_list[0][0][0].startswith("CREATE TABLE %s" % task.table) return @mock.patch("luigi.contrib.redshift.S3CopyToTable.does_schema_exist", return_value=False) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_to_missing_schema(self, mock_redshift_target, mock_does_exist): task = DummyS3CopyToTableKey(table="schema.table_with_schema") task.run() mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value executed_query = mock_cursor.execute.call_args_list[0][0][0] assert executed_query.startswith("CREATE SCHEMA IF NOT EXISTS schema") @mock.patch("luigi.contrib.redshift.S3CopyToTable.does_schema_exist", return_value=False) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_to_missing_schema_with_no_schema(self, mock_redshift_target, mock_does_exist): task = DummyS3CopyToTableKey(table="table_with_no_schema") task.run() mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value executed_query = mock_cursor.execute.call_args_list[0][0][0] assert not executed_query.startswith("CREATE SCHEMA IF NOT EXISTS") @mock.patch("luigi.contrib.redshift.S3CopyToTable.does_schema_exist", return_value=True) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_to_existing_schema_with_schema(self, mock_redshift_target, mock_does_exist): task = DummyS3CopyToTableKey(table="schema.table_with_schema") task.run() mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value executed_query = mock_cursor.execute.call_args_list[0][0][0] assert not executed_query.startswith("CREATE SCHEMA IF NOT EXISTS") @mock.patch("luigi.contrib.redshift.S3CopyToTable.does_table_exist", return_value=False) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_to_missing_table_with_compression_encodings(self, mock_redshift_target, 
mock_does_exist): """ Test missing table creation with compression encodings """ # Ensure `S3CopyToTable.create_table` does not throw an error. task = DummyS3CopyToTableWithCompressionEncodings() task.run() # Make sure the cursor was successfully used to create the table in # `create_table` as expected. mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value encode_string = ",".join("{name} {type} ENCODE {encoding}".format(name=name, type=type, encoding=encoding) for name, type, encoding in task.columns) assert mock_cursor.execute.call_args_list[0][0][0].startswith("CREATE TABLE %s (%s )" % (task.table, encode_string)) return @mock.patch("luigi.contrib.redshift.S3CopyToTable.does_table_exist", return_value=False) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_to_missing_table_with_table_constraints(self, mock_redshift_target, mock_does_exist): table_constraints = "PRIMARY KEY (COL1, COL2)" task = DummyS3CopyToTableKey(table_constraints=table_constraints) task.run() mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value columns_string = ",".join("{name} {type}".format(name=name, type=type) for name, type in task.columns) executed_query = mock_cursor.execute.call_args_list[0][0][0] expectation = "CREATE TABLE %s (%s , PRIMARY KEY (COL1, COL2))" % (task.table, columns_string) assert executed_query.startswith(expectation) @mock.patch("luigi.contrib.redshift.S3CopyToTable.copy") @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_to_temp_table(self, mock_redshift_target, mock_copy): task = DummyS3CopyToTempTable() task.run() # The mocked connection cursor passed to # S3CopyToTable.copy(self, cursor, f). mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value # `mock_redshift_target` is the mocked `RedshiftTarget` object # returned by S3CopyToTable.output(self). 
mock_redshift_target.assert_called_once_with( database=task.database, host=task.host, update_id=task.task_id, user=task.user, table=task.table, password=task.password, ) # Check if the `S3CopyToTable.s3_load_path` class attribute was # successfully referenced in the `S3CopyToTable.run` method, which is # in-turn passed to `S3CopyToTable.copy` and other functions in `run` # (see issue #995). mock_copy.assert_called_once_with(mock_cursor, task.s3_load_path()) # Check the SQL query in `S3CopyToTable.does_table_exist`. # temp table mock_cursor.execute.assert_any_call( "select 1 as table_exists from pg_table_def where tablename = lower(%s) limit 1", (task.table,), ) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_with_valid_columns(self, mock_redshift_target): task = DummyS3CopyToTableKey() task.run() # The mocked connection cursor passed to # S3CopyToTable.copy(self, cursor, f). mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value # `mock_redshift_target` is the mocked `RedshiftTarget` object # returned by S3CopyToTable.output(self). mock_redshift_target.assert_called_once_with( database=task.database, host=task.host, update_id=task.task_id, user=task.user, table=task.table, password=task.password, ) # To get the proper intendation in the multiline `COPY` statement the # SQL string was copied from redshift.py. mock_cursor.execute.assert_called_with( """ COPY {table} {colnames} from '{source}' CREDENTIALS '{creds}' {options} ;""".format( table="dummy_table", colnames="(some_text,some_int)", source="s3://bucket/key", creds="aws_access_key_id=key;aws_secret_access_key=secret", options="", ) ) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_with_default_columns(self, mock_redshift_target): task = DummyS3CopyToTableKey(columns=[]) task.run() # The mocked connection cursor passed to # S3CopyToTable.copy(self, cursor, f). 
mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value # `mock_redshift_target` is the mocked `RedshiftTarget` object # returned by S3CopyToTable.output(self). mock_redshift_target.assert_called_once_with( database=task.database, host=task.host, update_id=task.task_id, user=task.user, table=task.table, password=task.password, ) # To get the proper intendation in the multiline `COPY` statement the # SQL string was copied from redshift.py. mock_cursor.execute.assert_called_with( """ COPY {table} {colnames} from '{source}' CREDENTIALS '{creds}' {options} ;""".format(table="dummy_table", colnames="", source="s3://bucket/key", creds="aws_access_key_id=key;aws_secret_access_key=secret", options="") ) @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_with_nonetype_columns(self, mock_redshift_target): task = DummyS3CopyToTableKey(columns=None) task.run() # The mocked connection cursor passed to # S3CopyToTable.copy(self, cursor, f). mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value # `mock_redshift_target` is the mocked `RedshiftTarget` object # returned by S3CopyToTable.output(self). mock_redshift_target.assert_called_once_with( database=task.database, host=task.host, update_id=task.task_id, user=task.user, table=task.table, password=task.password, ) # To get the proper intendation in the multiline `COPY` statement the # SQL string was copied from redshift.py. 
mock_cursor.execute.assert_called_with( """ COPY {table} {colnames} from '{source}' CREDENTIALS '{creds}' {options} ;""".format(table="dummy_table", colnames="", source="s3://bucket/key", creds="aws_access_key_id=key;aws_secret_access_key=secret", options="") ) @pytest.mark.aws class TestS3CopyToSchemaTable(unittest.TestCase): @mock.patch("luigi.contrib.redshift.S3CopyToTable.copy") @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_s3_copy_to_table(self, mock_redshift_target, mock_copy): task = DummyS3CopyToTableKey(table="dummy_schema.dummy_table") task.run() # The mocked connection cursor passed to # S3CopyToTable.copy(self, cursor, f). mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value # Check the SQL query in `S3CopyToTable.does_table_exist`. mock_cursor.execute.assert_called_with( "select 1 as table_exists from information_schema.tables where table_schema = lower(%s) and table_name = lower(%s) limit 1", tuple(task.table.split(".")), ) class DummyRedshiftUnloadTask(luigi.contrib.redshift.RedshiftUnloadTask): # Class attributes taken from `DummyPostgresImporter` in # `../postgres_test.py`. host = "dummy_host" database = "dummy_database" user = "dummy_user" password = "dummy_password" table = luigi.Parameter(default="dummy_table") columns = ( ("some_text", "varchar(255)"), ("some_int", "int"), ) aws_access_key_id = "AWS_ACCESS_KEY" aws_secret_access_key = "AWS_SECRET_KEY" s3_unload_path = "s3://%s/%s" % (BUCKET, KEY) unload_options = "DELIMITER ',' ADDQUOTES GZIP ALLOWOVERWRITE PARALLEL OFF" def query(self): return "SELECT 'a' as col_a, current_date as col_b" @pytest.mark.aws class TestRedshiftUnloadTask(unittest.TestCase): @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_redshift_unload_command(self, mock_redshift_target): task = DummyRedshiftUnloadTask() task.run() # The mocked connection cursor passed to # RedshiftUnloadTask. 
mock_cursor = mock_redshift_target.return_value.connect.return_value.cursor.return_value # Check the Unload query. mock_cursor.execute.assert_called_with( "UNLOAD ( 'SELECT \\'a\\' as col_a, current_date as col_b' ) TO 's3://bucket/key' " "credentials 'aws_access_key_id=AWS_ACCESS_KEY;aws_secret_access_key=AWS_SECRET_KEY' " "DELIMITER ',' ADDQUOTES GZIP ALLOWOVERWRITE PARALLEL OFF;" ) class DummyRedshiftAutocommitQuery(luigi.contrib.redshift.RedshiftQuery): # Class attributes taken from `DummyPostgresImporter` in # `../postgres_test.py`. host = "dummy_host" database = "dummy_database" user = "dummy_user" password = "dummy_password" table = luigi.Parameter(default="dummy_table") autocommit = True def query(self): return "SELECT 'a' as col_a, current_date as col_b" @pytest.mark.aws class TestRedshiftAutocommitQuery(unittest.TestCase): @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_redshift_autocommit_query(self, mock_redshift_target): task = DummyRedshiftAutocommitQuery() task.run() # The mocked connection cursor passed to # RedshiftUnloadTask. mock_connect = mock_redshift_target.return_value.connect.return_value # Check the Unload query. 
self.assertTrue(mock_connect.autocommit) @pytest.mark.aws class TestRedshiftManifestTask(unittest.TestCase): def test_run(self): with mock_s3(): client = S3Client() client.s3.meta.client.create_bucket(Bucket=BUCKET) for key in FILES: k = "%s/%s" % (KEY, key) client.put_string("", "s3://%s/%s" % (BUCKET, k)) folder_path = "s3://%s/%s" % (BUCKET, KEY) path = "s3://%s/%s/%s" % (BUCKET, "manifest", "test.manifest") folder_paths = [folder_path] m = mock.mock_open() with mock.patch("luigi.contrib.s3.S3Target.open", m, create=True): t = redshift.RedshiftManifestTask(path, folder_paths) luigi.build([t], local_scheduler=True) expected_manifest_output = json.dumps(generate_manifest_json(folder_paths, FILES)) handle = m() handle.write.assert_called_with(expected_manifest_output) def test_run_multiple_paths(self): with mock_s3(): client = S3Client() client.s3.meta.client.create_bucket(Bucket=BUCKET) for parent in [KEY, KEY_2]: for key in FILES: k = "%s/%s" % (parent, key) client.put_string("", "s3://%s/%s" % (BUCKET, k)) folder_path_1 = "s3://%s/%s" % (BUCKET, KEY) folder_path_2 = "s3://%s/%s" % (BUCKET, KEY_2) folder_paths = [folder_path_1, folder_path_2] path = "s3://%s/%s/%s" % (BUCKET, "manifest", "test.manifest") m = mock.mock_open() with mock.patch("luigi.contrib.s3.S3Target.open", m, create=True): t = redshift.RedshiftManifestTask(path, folder_paths) luigi.build([t], local_scheduler=True) expected_manifest_output = json.dumps(generate_manifest_json(folder_paths, FILES)) handle = m() handle.write.assert_called_with(expected_manifest_output) ================================================ FILE: test/contrib/s3_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright (c) 2013 Mortar Data # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. 
# You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#

import os
import sys
import tempfile

import boto3

if sys.version_info[:2] <= (3, 11):
    from boto.s3 import key
import pytest
from botocore.exceptions import ClientError
from helpers import skipOnTravisAndGithubActions, unittest, with_config
from mock import patch
from moto import mock_s3, mock_sts
from target_test import FileSystemTargetTestMixin

from luigi.contrib.s3 import DeprecatedBotoClientException, FileNotFoundException, InvalidDeleteException, S3Client, S3Target
from luigi.target import MissingParentDirectory

if (3, 4, 0) <= sys.version_info[:3] < (3, 4, 3):
    # spulec/moto#308
    raise unittest.SkipTest("moto mock doesn't work with python3.4")


AWS_ACCESS_KEY = "XXXXXXXXXXXXXXXXXXXX"
AWS_SECRET_KEY = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
AWS_SESSION_TOKEN = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"


def create_bucket():
    """Create ``mybucket`` in us-east-1 and return the boto3 S3 resource."""
    conn = boto3.resource("s3", region_name="us-east-1")
    # We need to create the bucket since this is all in Moto's 'virtual' AWS account
    conn.create_bucket(Bucket="mybucket")
    return conn


@pytest.mark.aws
class TestS3Target(unittest.TestCase, FileSystemTargetTestMixin):
    """S3Target behaviour against a moto-mocked S3, plus the shared FileSystemTarget tests."""

    def setUp(self):
        f = tempfile.NamedTemporaryFile(mode="wb", delete=False)
        self.tempFileContents = b"I'm a temporary file for testing\nAnd this is the second line\nThis is the third."
        self.tempFilePath = f.name
        f.write(self.tempFileContents)
        f.close()
        self.addCleanup(os.remove, self.tempFilePath)

        self.mock_s3 = mock_s3()
        self.mock_s3.start()
        self.addCleanup(self.mock_s3.stop)

    def create_target(self, format=None, **kwargs):
        client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        create_bucket()
        return S3Target("s3://mybucket/test_file", client=client, format=format, **kwargs)

    def create_target_with_session(self, format=None, **kwargs):
        client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_SESSION_TOKEN)
        create_bucket()
        return S3Target("s3://mybucket/test_file", client=client, format=format, **kwargs)

    def test_read(self):
        client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        create_bucket()
        client.put(self.tempFilePath, "s3://mybucket/tempfile")
        t = S3Target("s3://mybucket/tempfile", client=client)
        read_file = t.open()
        file_str = read_file.read()
        self.assertEqual(self.tempFileContents, file_str.encode("utf-8"))

    def test_read_with_session(self):
        client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_SESSION_TOKEN)
        create_bucket()
        client.put(self.tempFilePath, "s3://mybucket/tempfile-with-session")
        t = S3Target("s3://mybucket/tempfile-with-session", client=client)
        read_file = t.open()
        file_str = read_file.read()
        self.assertEqual(self.tempFileContents, file_str.encode("utf-8"))

    def test_read_no_file(self):
        t = self.create_target()
        self.assertRaises(FileNotFoundException, t.open)

    def test_read_no_file_with_session(self):
        t = self.create_target_with_session()
        self.assertRaises(FileNotFoundException, t.open)

    def test_read_no_file_sse(self):
        t = self.create_target(encrypt_key=True)
        self.assertRaises(FileNotFoundException, t.open)

    @unittest.skipIf(tuple(sys.version_info) >= (3, 12), "boto is not supported on Python 3.12+")
    def test_read_iterator_long(self):
        # write a file that is 5X the boto buffersize
        # to test line buffering
        old_buffer = key.Key.BufferSize
        key.Key.BufferSize = 2
        try:
            tempf = tempfile.NamedTemporaryFile(mode="wb", delete=False)
            temppath = tempf.name
            # delete=False means nothing else removes this file.
            self.addCleanup(os.remove, temppath)
            firstline = "".zfill(key.Key.BufferSize * 5) + os.linesep
            secondline = "line two" + os.linesep
            thirdline = "line three" + os.linesep
            contents = firstline + secondline + thirdline
            tempf.write(contents.encode("utf-8"))
            tempf.close()

            client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
            create_bucket()
            remote_path = "s3://mybucket/largetempfile"
            client.put(temppath, remote_path)
            t = S3Target(remote_path, client=client)
            with t.open() as read_file:
                lines = list(read_file)
        finally:
            key.Key.BufferSize = old_buffer

        self.assertEqual(3, len(lines))
        self.assertEqual(firstline, lines[0])
        self.assertEqual(secondline, lines[1])
        self.assertEqual(thirdline, lines[2])

    def test_get_path(self):
        t = self.create_target()
        path = t.path
        self.assertEqual("s3://mybucket/test_file", path)

    def test_get_path_sse(self):
        t = self.create_target(encrypt_key=True)
        path = t.path
        self.assertEqual("s3://mybucket/test_file", path)


@pytest.mark.aws
class TestS3Client(unittest.TestCase):
    """S3Client operations against moto-mocked S3 and STS."""

    def setUp(self):
        f = tempfile.NamedTemporaryFile(mode="wb", delete=False)
        self.tempFilePath = f.name
        self.tempFileContents = b"I'm a temporary file for testing\n"
        f.write(self.tempFileContents)
        f.close()
        self.addCleanup(os.remove, self.tempFilePath)

        self.mock_s3 = mock_s3()
        self.mock_s3.start()
        self.mock_sts = mock_sts()
        self.mock_sts.start()
        self.addCleanup(self.mock_s3.stop)
        self.addCleanup(self.mock_sts.stop)

    @patch("boto3.resource")
    def test_init_without_init_or_config(self, mock):
        """If no config or arn provided, boto3 client
        should be called with default parameters.
        Delegating ENV or Task Role credential handling
        to boto3 itself.
        """
        S3Client().s3
        mock.assert_called_with("s3", aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None)

    @with_config({"s3": {"aws_access_key_id": "foo", "aws_secret_access_key": "bar"}})
    @patch("boto3.resource")
    def test_init_with_config(self, mock):
        S3Client().s3
        mock.assert_called_with("s3", aws_access_key_id="foo", aws_secret_access_key="bar", aws_session_token=None)

    @patch("boto3.resource")
    @patch("boto3.client")
    @with_config({"s3": {"aws_role_arn": "role", "aws_role_session_name": "name"}})
    def test_init_with_config_and_roles(self, sts_mock, s3_mock):
        S3Client().s3
        # `called_with` is not an assertion (a plain Mock attribute), so use
        # assert_called_with on the client that `boto3.client(...)` returned.
        sts_mock.return_value.assume_role.assert_called_with(RoleArn="role", RoleSessionName="name")

    @patch("boto3.client")
    def test_init_with_host_deprecated(self, mock):
        with self.assertRaises(DeprecatedBotoClientException):
            S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY, host="us-east-1").s3

    def test_put(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        s3_client.put(self.tempFilePath, "s3://mybucket/putMe")
        self.assertTrue(s3_client.exists("s3://mybucket/putMe"))

        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_SESSION_TOKEN)
        s3_client.put(self.tempFilePath, "s3://mybucket/putMe")
        self.assertTrue(s3_client.exists("s3://mybucket/putMe"))

    def test_put_no_such_bucket(self):
        # intentionally don't create bucket
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        with self.assertRaises(s3_client.s3.meta.client.exceptions.NoSuchBucket):
            s3_client.put(self.tempFilePath, "s3://mybucket/putMe")

    def test_put_sse_deprecated(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        with self.assertRaises(DeprecatedBotoClientException):
            s3_client.put(self.tempFilePath, "s3://mybucket/putMe", encrypt_key=True)

    def test_put_host_deprecated(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        with self.assertRaises(DeprecatedBotoClientException):
            s3_client.put(self.tempFilePath, "s3://mybucket/putMe", host="us-east-1")

    def test_put_string(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        s3_client.put_string("SOMESTRING", "s3://mybucket/putString")
        self.assertTrue(s3_client.exists("s3://mybucket/putString"))

    def test_put_string_no_such_bucket(self):
        # intentionally don't create bucket
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        with self.assertRaises(s3_client.s3.meta.client.exceptions.NoSuchBucket):
            s3_client.put_string("SOMESTRING", "s3://mybucket/putString")

    def test_put_string_sse_deprecated(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        with self.assertRaises(DeprecatedBotoClientException):
            s3_client.put_string("SOMESTRING", "s3://mybucket/putMe", encrypt_key=True)

    def test_put_string_host_deprecated(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        with self.assertRaises(DeprecatedBotoClientException):
            s3_client.put_string("SOMESTRING", "s3://mybucket/putMe", host="us-east-1")

    @skipOnTravisAndGithubActions("passes and fails intermitantly, suspecting it's a race condition not handled by moto")
    def test_put_multipart_multiple_parts_non_exact_fit(self):
        """
        Test a multipart put with two parts, where the parts are not exactly the split size.
        """
        # 5MB is minimum part size
        part_size = 8388608
        file_size = (part_size * 2) - 1000
        return self._run_multipart_test(part_size, file_size)

    @skipOnTravisAndGithubActions("passes and fails intermitantly, suspecting it's a race condition not handled by moto")
    def test_put_multipart_multiple_parts_exact_fit(self):
        """
        Test a multipart put with multiple parts, where the parts are exactly the split size.
        """
        # 5MB is minimum part size
        part_size = 8388608
        file_size = part_size * 2
        return self._run_multipart_test(part_size, file_size)

    def test_put_multipart_multiple_parts_with_sse_deprecated(self):
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        with self.assertRaises(DeprecatedBotoClientException):
            s3_client.put_multipart("path", "path", encrypt_key=True)

    def test_put_multipart_multiple_parts_with_host_deprecated(self):
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        with self.assertRaises(DeprecatedBotoClientException):
            s3_client.put_multipart("path", "path", host="us-east-1")

    def test_put_multipart_empty_file(self):
        """
        Test a multipart put with an empty file.
        """
        # 5MB is minimum part size
        part_size = 8388608
        file_size = 0
        return self._run_multipart_test(part_size, file_size)

    def test_put_multipart_less_than_split_size(self):
        """
        Test a multipart put with a file smaller than split size; should revert to regular put.
        """
        # 5MB is minimum part size
        part_size = 8388608
        file_size = 5000
        return self._run_multipart_test(part_size, file_size)

    def test_put_multipart_no_such_bucket(self):
        # intentionally don't create bucket
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        with self.assertRaises(s3_client.s3.meta.client.exceptions.NoSuchBucket):
            s3_client.put_multipart(self.tempFilePath, "s3://mybucket/putMe")

    def test_exists(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)

        self.assertTrue(s3_client.exists("s3://mybucket/"))
        self.assertTrue(s3_client.exists("s3://mybucket"))
        self.assertFalse(s3_client.exists("s3://mybucket/nope"))
        self.assertFalse(s3_client.exists("s3://mybucket/nope/"))

        s3_client.put(self.tempFilePath, "s3://mybucket/tempfile")
        self.assertTrue(s3_client.exists("s3://mybucket/tempfile"))
        self.assertFalse(s3_client.exists("s3://mybucket/temp"))

        s3_client.put(self.tempFilePath, "s3://mybucket/tempdir0_$folder$")
        self.assertTrue(s3_client.exists("s3://mybucket/tempdir0"))

        s3_client.put(self.tempFilePath, "s3://mybucket/tempdir1/")
        self.assertTrue(s3_client.exists("s3://mybucket/tempdir1"))

        s3_client.put(self.tempFilePath, "s3://mybucket/tempdir2/subdir")
        self.assertTrue(s3_client.exists("s3://mybucket/tempdir2"))
        self.assertFalse(s3_client.exists("s3://mybucket/tempdir"))

    def test_get(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        s3_client.put(self.tempFilePath, "s3://mybucket/putMe")

        tmp_file = tempfile.NamedTemporaryFile(delete=True)
        tmp_file_path = tmp_file.name

        s3_client.get("s3://mybucket/putMe", tmp_file_path)
        with open(tmp_file_path, "r") as f:
            content = f.read()
        self.assertEqual(content, self.tempFileContents.decode("utf-8"))
        tmp_file.close()

    def test_get_as_bytes(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        s3_client.put(self.tempFilePath, "s3://mybucket/putMe")

        contents = s3_client.get_as_bytes("s3://mybucket/putMe")

        self.assertEqual(contents, self.tempFileContents)

    def test_get_as_string(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        s3_client.put(self.tempFilePath, "s3://mybucket/putMe2")

        contents = s3_client.get_as_string("s3://mybucket/putMe2")

        self.assertEqual(contents, self.tempFileContents.decode("utf-8"))

    def test_get_as_string_latin1(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        s3_client.put(self.tempFilePath, "s3://mybucket/putMe3")

        contents = s3_client.get_as_string("s3://mybucket/putMe3", encoding="ISO-8859-1")

        self.assertEqual(contents, self.tempFileContents.decode("ISO-8859-1"))

    def test_get_key(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        s3_client.put(self.tempFilePath, "s3://mybucket/key_to_find")
        self.assertTrue(s3_client.get_key("s3://mybucket/key_to_find").key)
        self.assertFalse(s3_client.get_key("s3://mybucket/does_not_exist"))

    def test_isdir(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        self.assertTrue(s3_client.isdir("s3://mybucket"))

        s3_client.put(self.tempFilePath, "s3://mybucket/tempdir0_$folder$")
        self.assertTrue(s3_client.isdir("s3://mybucket/tempdir0"))

        s3_client.put(self.tempFilePath, "s3://mybucket/tempdir1/")
        self.assertTrue(s3_client.isdir("s3://mybucket/tempdir1"))

        s3_client.put(self.tempFilePath, "s3://mybucket/key")
        self.assertFalse(s3_client.isdir("s3://mybucket/key"))

    def test_mkdir(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        self.assertTrue(s3_client.isdir("s3://mybucket"))
        s3_client.mkdir("s3://mybucket")

        s3_client.mkdir("s3://mybucket/dir")
        self.assertTrue(s3_client.isdir("s3://mybucket/dir"))

        self.assertRaises(MissingParentDirectory, s3_client.mkdir, "s3://mybucket/dir/foo/bar", parents=False)

        self.assertFalse(s3_client.isdir("s3://mybucket/dir/foo/bar"))

    def test_listdir(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)

        s3_client.put_string("", "s3://mybucket/hello/frank")
        s3_client.put_string("", "s3://mybucket/hello/world")

        self.assertEqual(["s3://mybucket/hello/frank", "s3://mybucket/hello/world"], list(s3_client.listdir("s3://mybucket/hello")))

    def test_list(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)

        s3_client.put_string("", "s3://mybucket/hello/frank")
        s3_client.put_string("", "s3://mybucket/hello/world")

        self.assertEqual(["frank", "world"], list(s3_client.list("s3://mybucket/hello")))

    def test_listdir_key(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)

        s3_client.put_string("", "s3://mybucket/hello/frank")
        s3_client.put_string("", "s3://mybucket/hello/world")

        self.assertEqual(
            [True, True], [s3_client.exists("s3://" + x.bucket_name + "/" + x.key) for x in s3_client.listdir("s3://mybucket/hello", return_key=True)]
        )

    def test_list_key(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)

        s3_client.put_string("", "s3://mybucket/hello/frank")
        s3_client.put_string("", "s3://mybucket/hello/world")

        # Exercise list() (not listdir(), which test_listdir_key already covers).
        self.assertEqual(
            [True, True], [s3_client.exists("s3://" + x.bucket_name + "/" + x.key) for x in s3_client.list("s3://mybucket/hello", return_key=True)]
        )

    def test_remove_bucket_dne(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        self.assertRaises(ClientError, lambda: s3_client.remove("s3://bucketdoesnotexist/file"))

    def test_remove_file_dne(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        self.assertFalse(s3_client.remove("s3://mybucket/doesNotExist"))

    def test_remove_file(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        s3_client.put(self.tempFilePath, "s3://mybucket/existingFile0")
        self.assertTrue(s3_client.remove("s3://mybucket/existingFile0"))
        self.assertFalse(s3_client.exists("s3://mybucket/existingFile0"))

    def test_remove_invalid(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        self.assertRaises(InvalidDeleteException, lambda: s3_client.remove("s3://mybucket/"))

    def test_remove_invalid_no_slash(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        self.assertRaises(InvalidDeleteException, lambda: s3_client.remove("s3://mybucket"))

    def test_remove_dir_not_recursive(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        s3_client.put(self.tempFilePath, "s3://mybucket/removemedir/file")
        self.assertRaises(InvalidDeleteException, lambda: s3_client.remove("s3://mybucket/removemedir", recursive=False))

    def test_remove_dir(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)

        # test that the marker file created by Hadoop S3 Native FileSystem is removed
        s3_client.put(self.tempFilePath, "s3://mybucket/removemedir/file")
        s3_client.put_string("", "s3://mybucket/removemedir_$folder$")
        self.assertTrue(s3_client.remove("s3://mybucket/removemedir"))
        self.assertFalse(s3_client.exists("s3://mybucket/removemedir_$folder$"))

    def test_remove_dir_batch(self):
        create_bucket()
        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)
        for i in range(0, 2000):
            s3_client.put(self.tempFilePath, "s3://mybucket/removemedir/file{i}".format(i=i))
        self.assertTrue(s3_client.remove("s3://mybucket/removemedir/"))
        # Check the directory that was actually removed (not a misspelled path).
        self.assertFalse(s3_client.exists("s3://mybucket/removemedir/"))

    @skipOnTravisAndGithubActions("passes and fails intermitantly, suspecting it's a race condition not handled by moto")
    def test_copy_multiple_parts_non_exact_fit(self):
        """
        Test a copy with multiple parts, where the parts are not exactly the split size.
        """
        # First, put a file into S3
        self._run_copy_test(self.test_put_multipart_multiple_parts_non_exact_fit)

    @skipOnTravisAndGithubActions("passes and fails intermitantly, suspecting it's a race condition not handled by moto")
    def test_copy_multiple_parts_exact_fit(self):
        """
        Test a copy multiple parts, where the parts are exactly the split size.
        """
        self._run_copy_test(self.test_put_multipart_multiple_parts_exact_fit)

    def test_copy_less_than_split_size(self):
        """
        Test a copy with a file smaller than split size; should revert to regular put.
        """
        self._run_copy_test(self.test_put_multipart_less_than_split_size)

    def test_copy_empty_file(self):
        """
        Test a copy with an empty file.
        """
        self._run_copy_test(self.test_put_multipart_empty_file)

    @mock_s3
    def test_copy_empty_dir(self):
        """
        Test copying an empty dir
        """
        create_bucket()

        s3_dir = "s3://mybucket/copydir/"

        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)

        s3_client.mkdir(s3_dir)
        self.assertTrue(s3_client.exists(s3_dir))

        s3_dest = "s3://mybucket/copydir_new/"
        response = s3_client.copy(s3_dir, s3_dest)

        self._run_copy_response_test(response, expected_num=0, expected_size=0)

    @mock_s3
    @skipOnTravisAndGithubActions("https://travis-ci.org/spotify/luigi/jobs/145895385")
    def test_copy_dir(self):
        """
        Test copying 20 files from one folder to another
        """
        create_bucket()
        n = 20
        copy_part_size = (1024**2) * 5

        # Note we can't test the multipart copy due to moto issue #526
        # so here I have to keep the file size smaller than the copy_part_size
        file_size = 5000

        s3_dir = "s3://mybucket/copydir/"
        file_contents = b"a" * file_size
        tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=True)
        tmp_file_path = tmp_file.name
        tmp_file.write(file_contents)
        tmp_file.flush()

        s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)

        for i in range(n):
            file_path = s3_dir + str(i)
            s3_client.put_multipart(tmp_file_path, file_path)
            self.assertTrue(s3_client.exists(file_path))

        s3_dest = "s3://mybucket/copydir_new/"
        response = s3_client.copy(s3_dir, s3_dest, threads=10, part_size=copy_part_size)

        self._run_copy_response_test(response, expected_num=n, expected_size=(n * file_size))

        for i in range(n):
            original_size = s3_client.get_key(s3_dir + str(i)).size
            copy_size = s3_client.get_key(s3_dest + str(i)).size
            self.assertEqual(original_size, copy_size)

    def test__path_to_bucket_and_key(self):
        self.assertEqual(("bucket", "key"), S3Client._path_to_bucket_and_key("s3://bucket/key"))

    def test__path_to_bucket_and_key_with_question_mark(self):
        self.assertEqual(("bucket", "key?blade"), S3Client._path_to_bucket_and_key("s3://bucket/key?blade"))

    @mock_s3
    def _run_copy_test(self, put_method, is_multipart=False):
        create_bucket()
        # Run the
method to put the file into s3 into the first place expected_num, expected_size = put_method() # As all the multipart put methods use `self._run_multipart_test` # we can just use this key original = "s3://mybucket/putMe" copy = "s3://mybucket/putMe_copy" # Copy the file from old location to new s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) if is_multipart: # 5MB is minimum part size, use it here so we don't have to generate huge files to test # the multipart upload in moto part_size = (1024**2) * 5 response = s3_client.copy(original, copy, part_size=part_size, threads=4) else: response = s3_client.copy(original, copy, threads=4) self._run_copy_response_test(response, expected_num=expected_num, expected_size=expected_size) # We can't use etags to compare between multipart and normal keys, # so we fall back to using the file size original_size = s3_client.get_key(original).size copy_size = s3_client.get_key(copy).size self.assertEqual(original_size, copy_size) @mock_s3 def _run_multipart_test(self, part_size, file_size, **kwargs): create_bucket() file_contents = b"a" * file_size s3_path = "s3://mybucket/putMe" tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=True) tmp_file_path = tmp_file.name tmp_file.write(file_contents) tmp_file.flush() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put_multipart(tmp_file_path, s3_path, part_size=part_size, **kwargs) self.assertTrue(s3_client.exists(s3_path)) file_size = os.path.getsize(tmp_file.name) key_size = s3_client.get_key(s3_path).size self.assertEqual(file_size, key_size) tmp_file.close() return 1, key_size def _run_copy_response_test(self, response, expected_num=None, expected_size=None): num, size = response self.assertIsInstance(response, tuple) # only check >= minimum possible value if not provided expected value self.assertEqual(num, expected_num) if expected_num is not None else self.assertGreaterEqual(num, 1) self.assertEqual(size, expected_size) if expected_size is not None else 
self.assertGreaterEqual(size, 0) ================================================ FILE: test/contrib/salesforce_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright (c) 2016 Simply Measured # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of # the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations under # the License. # # This method will be used by the mock to replace requests.get """ Unit test for the Salesforce contrib package """ import re import mock import pytest from helpers import unittest from luigi.contrib.salesforce import QuerySalesforce, SalesforceAPI from luigi.mock import MockTarget def mocked_requests_get(*args, **kwargs): class MockResponse: def __init__(self, body, status_code): self.body = body self.status_code = status_code @property def text(self): return self.body def raise_for_status(self): return None result_list = ( '123412351236' ) return MockResponse(result_list, 200) # Keep open around so we can use it in the mock responses old__open = open def mocked_open(*args, **kwargs): if re.match("job_data", str(args[0])): return MockTarget(args[0]).open(args[1]) else: return old__open(*args) @pytest.mark.contrib class TestSalesforceAPI(unittest.TestCase): # We patch 'requests.get' with our own method. The mock object is passed in to our test case method. 
@mock.patch("requests.get", side_effect=mocked_requests_get) def test_deprecated_results_warning(self, mock_get): sf = SalesforceAPI("xx", "xx", "xx") with self.assertWarnsRegex(UserWarning, r"get_batch_results is deprecated"): result_id = sf.get_batch_results("job_id", "batch_id") self.assertEqual("1234", result_id) @mock.patch("requests.get", side_effect=mocked_requests_get) def test_result_ids(self, mock_get): sf = SalesforceAPI("xx", "xx", "xx") result_ids = sf.get_batch_result_ids("job_id", "batch_id") self.assertEqual(["1234", "1235", "1236"], result_ids) class TestQuerySalesforce(QuerySalesforce): def output(self): return MockTarget("job_data.csv") @property def object_name(self): return "dual" @property def soql(self): return "SELECT * FROM %s" % self.object_name @pytest.mark.contrib class TestSalesforceQuery(unittest.TestCase): @mock.patch("builtins.open", side_effect=mocked_open) def setUp(self, mock_open): MockTarget.fs.clear() self.result_ids = ["a", "b", "c"] counter = 1 self.all_lines = "Lines\n" self.header = "Lines" for i, id in enumerate(self.result_ids): filename = "%s.%d" % ("job_data.csv", i) with MockTarget(filename).open("w") as f: line = "%d line\n%d line" % ((counter), (counter + 1)) f.write(self.header + "\n" + line + "\n") self.all_lines += line + "\n" counter += 2 @mock.patch("builtins.open", side_effect=mocked_open) def test_multi_csv_download(self, mock_open): qsf = TestQuerySalesforce() qsf.merge_batch_results(self.result_ids) self.assertEqual(MockTarget(qsf.output().path).open("r").read(), self.all_lines) ================================================ FILE: test/contrib/scalding_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import os import random import shutil import tempfile import unittest import mock import pytest import luigi from luigi.contrib import scalding class MyScaldingTask(scalding.ScaldingJobTask): scala_source = luigi.Parameter() def source(self): return self.scala_source @pytest.mark.contrib class ScaldingTest(unittest.TestCase): def setUp(self): self.scalding_home = os.path.join(tempfile.gettempdir(), "scalding-%09d" % random.randint(0, 999999999)) os.mkdir(self.scalding_home) self.lib_dir = os.path.join(self.scalding_home, "lib") os.mkdir(self.lib_dir) os.mkdir(os.path.join(self.scalding_home, "provided")) os.mkdir(os.path.join(self.scalding_home, "libjars")) f = open(os.path.join(self.lib_dir, "scalding-core-foo"), "w") f.close() self.scala_source = os.path.join(self.scalding_home, "my_source.scala") f = open(self.scala_source, "w") f.write("class foo extends Job") f.close() os.environ["SCALDING_HOME"] = self.scalding_home def tearDown(self): shutil.rmtree(self.scalding_home) @mock.patch("subprocess.check_call") @mock.patch("luigi.contrib.hadoop.run_and_track_hadoop_job") def test_scalding(self, check_call, track_job): success = luigi.run(["MyScaldingTask", "--scala-source", self.scala_source, "--local-scheduler", "--no-lock"]) self.assertTrue(success) # TODO: check more stuff ================================================ FILE: test/contrib/sge_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in 
compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import logging import os import os.path import subprocess import unittest from glob import glob import pytest from mock import patch import luigi from luigi.contrib.sge import SGEJobTask, _parse_qstat_state DEFAULT_HOME = "/home" logger = logging.getLogger("luigi-interface") QSTAT_OUTPUT = """job-ID prior name user state submit/start at queue slots ja-task-ID ----------------------------------------------------------------------------------------------------------------- 1 0.55500 job1 root r 07/09/2015 16:56:45 all.q@node001 1 2 0.55500 job2 root qw 07/09/2015 16:56:42 1 3 0.00000 job3 root t 07/09/2015 16:56:45 1 """ def on_sge_master(): try: subprocess.check_output("qstat", shell=True) return True except subprocess.CalledProcessError: return False @pytest.mark.contrib class TestSGEWrappers(unittest.TestCase): def test_track_job(self): """`track_job` returns the state using qstat""" self.assertEqual(_parse_qstat_state(QSTAT_OUTPUT, 1), "r") self.assertEqual(_parse_qstat_state(QSTAT_OUTPUT, 2), "qw") self.assertEqual(_parse_qstat_state(QSTAT_OUTPUT, 3), "t") self.assertEqual(_parse_qstat_state("", 1), "u") self.assertEqual(_parse_qstat_state("", 4), "u") class TestJobTask(SGEJobTask): """Simple SGE job: write a test file to NSF shared drive and waits a minute""" i = luigi.Parameter() def work(self): logger.info("Running test job...") with open(self.output().path, "w") as f: f.write("this is a test\n") def output(self): return luigi.LocalTarget(os.path.join(DEFAULT_HOME, "testfile_" + str(self.i))) @pytest.mark.contrib class 
TestSGEJob(unittest.TestCase): """Test from SGE master node""" def test_run_job(self): if on_sge_master(): outfile = os.path.join(DEFAULT_HOME, "testfile_1") tasks = [TestJobTask(i=str(i), n_cpu=1) for i in range(3)] luigi.build(tasks, local_scheduler=True, workers=3) self.assertTrue(os.path.exists(outfile)) @patch("subprocess.check_output") def test_run_job_with_dump(self, mock_check_output): mock_check_output.side_effect = ['Your job 12345 ("test_job") has been submitted', ""] task = TestJobTask(i="1", n_cpu=1, shared_tmp_dir="/tmp") luigi.build([task], local_scheduler=True) self.assertEqual(mock_check_output.call_count, 2) def tearDown(self): for fpath in glob(os.path.join(DEFAULT_HOME, "test_file_*")): try: os.remove(fpath) except OSError: pass ================================================ FILE: test/contrib/spark_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import os import pickle import sys import unittest from functools import partial from io import BytesIO from multiprocessing import Value from subprocess import Popen import pytest from helpers import temporary_unloaded_module, with_config from mock import MagicMock, call, mock, patch import luigi import luigi.contrib.hdfs from luigi.contrib.external_program import ExternalProgramRunError from luigi.contrib.spark import PySparkTask, SparkSubmitTask from luigi.mock import MockTarget def poll_generator(): yield None yield 1 def setup_run_process(proc): poll_gen = poll_generator() proc.return_value.poll = lambda: next(poll_gen) proc.return_value.returncode = 0 proc.return_value.stdout = BytesIO() proc.return_value.stderr = BytesIO() class TestSparkSubmitTask(SparkSubmitTask): name = "AppName" entry_class = "org.test.MyClass" jars = ["jars/my.jar"] py_files = ["file1.py", "file2.py"] files = ["file1", "file2"] conf = {"Prop": "Value"} properties_file = "conf/spark-defaults.conf" driver_memory = "4G" driver_java_options = "-Xopt" driver_library_path = "library/path" driver_class_path = "class/path" executor_memory = "8G" driver_cores = 8 supervise = True total_executor_cores = 150 executor_cores = 10 queue = "queue" num_executors = 2 archives = ["archive1", "archive2"] app = "file" pyspark_python = "/a/b/c" pyspark_driver_python = "/b/c/d" hadoop_user_name = "luigiuser" def app_options(self): return ["arg1", "arg2"] def output(self): return luigi.LocalTarget("output") class TestDefaultSparkSubmitTask(SparkSubmitTask): app = "test.py" def output(self): return luigi.LocalTarget("output") class TestPySparkTask(PySparkTask): def input(self): return MockTarget("input") def output(self): return MockTarget("output") def main(self, sc, *args): sc.textFile(self.input().path).saveAsTextFile(self.output().path) class TestPySparkSessionTask(PySparkTask): def input(self): return MockTarget("input") def output(self): return MockTarget("output") def main(self, spark, *args): 
spark.sql(self.input().path).write.saveAsTable(self.output().path) class MessyNamePySparkTask(TestPySparkTask): name = "AppName(a,b,c,1:2,3/4)" @pytest.mark.apache class SparkSubmitTaskTest(unittest.TestCase): ss = "ss-stub" @with_config({"spark": {"spark-submit": ss, "master": "yarn-client", "hadoop-conf-dir": "path", "deploy-mode": "client"}}) @patch("luigi.contrib.external_program.subprocess.Popen") def test_run(self, proc): setup_run_process(proc) job = TestSparkSubmitTask() job.run() self.assertEqual( proc.call_args[0][0], [ "ss-stub", "--master", "yarn-client", "--deploy-mode", "client", "--name", "AppName", "--class", "org.test.MyClass", "--jars", "jars/my.jar", "--py-files", "file1.py,file2.py", "--files", "file1,file2", "--archives", "archive1,archive2", "--conf", "Prop=Value", "--conf", "spark.pyspark.python=/a/b/c", "--conf", "spark.pyspark.driver.python=/b/c/d", "--properties-file", "conf/spark-defaults.conf", "--driver-memory", "4G", "--driver-java-options", "-Xopt", "--driver-library-path", "library/path", "--driver-class-path", "class/path", "--executor-memory", "8G", "--driver-cores", "8", "--supervise", "--total-executor-cores", "150", "--executor-cores", "10", "--queue", "queue", "--num-executors", "2", "file", "arg1", "arg2", ], ) @with_config({"spark": {"hadoop-conf-dir": "path"}}) @patch("luigi.contrib.external_program.subprocess.Popen") def test_environment_is_set_correctly(self, proc): setup_run_process(proc) job = TestSparkSubmitTask() job.run() assert job._conf == {"Prop": "Value", "spark.pyspark.python": "/a/b/c", "spark.pyspark.driver.python": "/b/c/d"} assert job.program_environment()["HADOOP_USER_NAME"] == "luigiuser" self.assertIn("HADOOP_CONF_DIR", proc.call_args[1]["env"]) self.assertEqual(proc.call_args[1]["env"]["HADOOP_CONF_DIR"], "path") @with_config( { "spark": { "spark-submit": ss, "master": "spark://host:7077", "conf": "prop1=val1", "jars": "jar1.jar,jar2.jar", "files": "file1,file2", "py-files": "file1.py,file2.py", 
"archives": "archive1", } } ) @patch("luigi.contrib.external_program.subprocess.Popen") def test_defaults(self, proc): proc.return_value.returncode = 0 job = TestDefaultSparkSubmitTask() job.run() self.assertEqual( proc.call_args[0][0], [ "ss-stub", "--master", "spark://host:7077", "--jars", "jar1.jar,jar2.jar", "--py-files", "file1.py,file2.py", "--files", "file1,file2", "--archives", "archive1", "--conf", "prop1=val1", "test.py", ], ) @patch("luigi.contrib.external_program.logger") @patch("luigi.contrib.external_program.tempfile.TemporaryFile") @patch("luigi.contrib.external_program.subprocess.Popen") def test_handle_failed_job(self, proc, file, logger): proc.return_value.returncode = 1 file.return_value = BytesIO(b"spark test error") try: job = TestSparkSubmitTask() job.run() except ExternalProgramRunError as e: self.assertEqual(e.err, "spark test error") self.assertIn("spark test error", str(e)) self.assertIn(call.info("Program stderr:\nspark test error"), logger.mock_calls) else: self.fail("Should have thrown ExternalProgramRunError") @patch("luigi.contrib.external_program.logger") @patch("luigi.contrib.external_program.tempfile.TemporaryFile") @patch("luigi.contrib.external_program.subprocess.Popen") def test_dont_log_stderr_on_success(self, proc, file, logger): proc.return_value.returncode = 0 file.return_value = BytesIO(b"spark normal error output") job = TestSparkSubmitTask() job.run() self.assertNotIn(call.info("Program stderr:\nspark normal error output"), logger.mock_calls) @patch("luigi.contrib.external_program.subprocess.Popen") def test_app_must_be_set(self, proc): with self.assertRaises(NotImplementedError): job = SparkSubmitTask() job.run() @patch("luigi.contrib.external_program.subprocess.Popen") def test_app_interruption(self, proc): def interrupt(): raise KeyboardInterrupt() proc.return_value.wait = interrupt try: job = TestSparkSubmitTask() job.run() except KeyboardInterrupt: pass proc.return_value.kill.check_called() @with_config({"spark": 
{"deploy-mode": "client"}}) def test_tracking_url_is_found_in_stderr_client_mode(self): test_val = Value("i", 0) def fake_set_tracking_url(val, url): if url == "http://10.66.76.155:4040": val.value += 1 def Popen_wrap(args, **kwargs): return Popen('>&2 echo "INFO SparkUI: Bound SparkUI to 0.0.0.0, and started at http://10.66.76.155:4040"', shell=True, **kwargs) task = TestSparkSubmitTask() with mock.patch("luigi.contrib.external_program.subprocess.Popen", wraps=Popen_wrap): with mock.patch.object(task, "set_tracking_url", new=partial(fake_set_tracking_url, test_val)): task.run() self.assertEqual(test_val.value, 1) @with_config({"spark": {"deploy-mode": "cluster"}}) def test_tracking_url_is_found_in_stderr_cluster_mode(self): test_val = Value("i", 0) def fake_set_tracking_url(val, url): if url == "https://127.0.0.1:4040": val.value += 1 def Popen_wrap(args, **kwargs): return Popen('>&2 echo "tracking URL: https://127.0.0.1:4040"', shell=True, **kwargs) task = TestSparkSubmitTask() with mock.patch("luigi.contrib.external_program.subprocess.Popen", wraps=Popen_wrap): with mock.patch.object(task, "set_tracking_url", new=partial(fake_set_tracking_url, test_val)): task.run() self.assertEqual(test_val.value, 1) @pytest.mark.apache class PySparkTaskTest(unittest.TestCase): ss = "ss-stub" @with_config({"spark": {"spark-submit": ss, "master": "spark://host:7077", "deploy-mode": "client"}}) @patch("luigi.contrib.external_program.subprocess.Popen") def test_run(self, proc): setup_run_process(proc) job = TestPySparkTask() job.run() proc_arg_list = proc.call_args[0][0] self.assertEqual(proc_arg_list[0:7], ["ss-stub", "--master", "spark://host:7077", "--deploy-mode", "client", "--name", "TestPySparkTask"]) self.assertTrue(os.path.exists(proc_arg_list[7])) self.assertTrue(proc_arg_list[8].endswith("TestPySparkTask.pickle")) @with_config({"spark": {"spark-submit": ss, "master": "spark://host:7077", "deploy-mode": "client"}}) @patch("luigi.contrib.external_program.subprocess.Popen") 
def test_run_with_pickle_dump(self, proc): setup_run_process(proc) job = TestPySparkTask() luigi.build([job], local_scheduler=True) self.assertEqual(proc.call_count, 1) proc_arg_list = proc.call_args[0][0] self.assertEqual(proc_arg_list[0:7], ["ss-stub", "--master", "spark://host:7077", "--deploy-mode", "client", "--name", "TestPySparkTask"]) self.assertTrue(os.path.exists(proc_arg_list[7])) self.assertTrue(proc_arg_list[8].endswith("TestPySparkTask.pickle")) @with_config({"spark": {"spark-submit": ss, "master": "spark://host:7077", "deploy-mode": "cluster"}}) @patch("luigi.contrib.external_program.subprocess.Popen") def test_run_with_cluster(self, proc): setup_run_process(proc) job = TestPySparkTask() job.run() proc_arg_list = proc.call_args[0][0] self.assertEqual(proc_arg_list[0:8], ["ss-stub", "--master", "spark://host:7077", "--deploy-mode", "cluster", "--name", "TestPySparkTask", "--files"]) self.assertTrue(proc_arg_list[8].endswith("TestPySparkTask.pickle")) self.assertTrue(os.path.exists(proc_arg_list[9])) self.assertEqual("TestPySparkTask.pickle", proc_arg_list[10]) @patch.dict("sys.modules", {"pyspark": MagicMock()}) @patch("pyspark.SparkContext") def test_pyspark_runner(self, spark_context): sc = spark_context.return_value def mock_spark_submit(task): from luigi.contrib.pyspark_runner import PySparkRunner PySparkRunner(*task.app_command()[1:]).run() # Check py-package exists self.assertTrue(os.path.exists(sc.addPyFile.call_args[0][0])) # Check that main module containing the task exists. run_path = os.path.dirname(task.app_command()[1]) self.assertTrue(os.path.exists(os.path.join(run_path, os.path.basename(__file__)))) # Check that the python path contains the run_path self.assertTrue(run_path in sys.path) # Check if find_class finds the class for the correct module name. 
with open(task.app_command()[1], "rb") as fp: self.assertTrue(pickle.Unpickler(fp).find_class("spark_test", "TestPySparkTask")) with patch.object(SparkSubmitTask, "run", mock_spark_submit): job = TestPySparkTask() with temporary_unloaded_module(b"") as task_module: with_config({"spark": {"py-packages": task_module}})(job.run)() sc.textFile.assert_called_with("input") sc.textFile.return_value.saveAsTextFile.assert_called_with("output") sc.stop.assert_called_once_with() def test_pyspark_session_runner_use_spark_session_true(self): pyspark = MagicMock() pyspark.__version__ = "2.1.0" pyspark_sql = MagicMock() with patch.dict(sys.modules, {"pyspark": pyspark, "pyspark.sql": pyspark_sql}): spark = pyspark_sql.SparkSession.builder.config.return_value.enableHiveSupport.return_value.getOrCreate.return_value sc = spark.sparkContext def mock_spark_submit(task): from luigi.contrib.pyspark_runner import PySparkSessionRunner PySparkSessionRunner(*task.app_command()[1:]).run() # Check py-package exists self.assertTrue(os.path.exists(sc.addPyFile.call_args[0][0])) # Check that main module containing the task exists. run_path = os.path.dirname(task.app_command()[1]) self.assertTrue(os.path.exists(os.path.join(run_path, os.path.basename(__file__)))) # Check that the python path contains the run_path self.assertTrue(run_path in sys.path) # Check if find_class finds the class for the correct module name. 
with open(task.app_command()[1], "rb") as fp: self.assertTrue(pickle.Unpickler(fp).find_class("spark_test", "TestPySparkSessionTask")) with patch.object(SparkSubmitTask, "run", mock_spark_submit): job = TestPySparkSessionTask() with temporary_unloaded_module(b"") as task_module: with_config({"spark": {"py-packages": task_module}})(job.run)() spark.sql.assert_called_with("input") spark.sql.return_value.write.saveAsTable.assert_called_with("output") spark.stop.assert_called_once_with() def test_pyspark_session_runner_use_spark_session_true_spark1(self): pyspark = MagicMock() pyspark.__version__ = "1.6.3" pyspark_sql = MagicMock() with patch.dict(sys.modules, {"pyspark": pyspark, "pyspark.sql": pyspark_sql}): def mock_spark_submit(task): from luigi.contrib.pyspark_runner import PySparkSessionRunner self.assertRaises(RuntimeError, PySparkSessionRunner(*task.app_command()[1:]).run) with patch.object(SparkSubmitTask, "run", mock_spark_submit): job = TestPySparkSessionTask() with temporary_unloaded_module(b"") as task_module: with_config({"spark": {"py-packages": task_module}})(job.run)() @patch("luigi.contrib.external_program.subprocess.Popen") def test_name_cleanup(self, proc): setup_run_process(proc) job = MessyNamePySparkTask() job.run() assert "AppName_a_b_c_1_2_3_4_" in job.run_path ================================================ FILE: test/contrib/sqla_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright (c) 2015 Gouthaman Balaraman # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of # the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the # License for the specific language governing permissions and limitations under # the License. # """ This file implements unit test cases for luigi/contrib/sqla.py Author: Gouthaman Balaraman Date: 01/02/2015 """ import os import shutil import tempfile import unittest import pytest import sqlalchemy from helpers import skipOnTravisAndGithubActions import luigi from luigi.contrib import sqla from luigi.mock import MockTarget class BaseTask(luigi.Task): TASK_LIST = ["item%d\tproperty%d\n" % (i, i) for i in range(10)] def output(self): return MockTarget("BaseTask", mirror_on_stderr=True) def run(self): out = self.output().open("w") for task in self.TASK_LIST: out.write(task) out.close() @pytest.mark.contrib class TestSQLA(unittest.TestCase): NUM_WORKERS = 1 def _clear_tables(self): meta = sqlalchemy.MetaData() meta.reflect(bind=self.engine) for table in reversed(meta.sorted_tables): self.engine.execute(table.delete()) def setUp(self): self.tempdir = tempfile.mkdtemp() self.connection_string = self.get_connection_string() self.connect_args = {"timeout": 5.0} self.engine = sqlalchemy.create_engine(self.connection_string, connect_args=self.connect_args) # Create SQLATask and store in self class SQLATask(sqla.CopyToTable): columns = [(["item", sqlalchemy.String(64)], {}), (["property", sqlalchemy.String(64)], {})] connection_string = self.connection_string connect_args = self.connect_args table = "item_property" chunk_size = 1 def requires(self): return BaseTask() self.SQLATask = SQLATask def tearDown(self): self._clear_tables() if os.path.exists(self.tempdir): shutil.rmtree(self.tempdir) def get_connection_string(self, db="sqlatest.db"): return "sqlite:///{path}".format(path=os.path.join(self.tempdir, db)) def test_create_table(self): """ Test that this method creates table that we require :return: """ class TestSQLData(sqla.CopyToTable): connection_string = self.connection_string connect_args = self.connect_args table = "test_table" columns = [(["id", 
sqlalchemy.Integer], dict(primary_key=True)), (["name", sqlalchemy.String(64)], {}), (["value", sqlalchemy.String(64)], {})] chunk_size = 1 def output(self): pass sql_copy = TestSQLData() eng = sqlalchemy.create_engine(TestSQLData.connection_string) self.assertFalse(eng.dialect.has_table(eng.connect(), TestSQLData.table)) sql_copy.create_table(eng) self.assertTrue(eng.dialect.has_table(eng.connect(), TestSQLData.table)) # repeat and ensure it just binds to existing table sql_copy.create_table(eng) def test_create_table_raises_no_columns(self): """ Check that the test fails when the columns are not set :return: """ class TestSQLData(sqla.CopyToTable): connection_string = self.connection_string table = "test_table" columns = [] chunk_size = 1 def output(self): pass sql_copy = TestSQLData() eng = sqlalchemy.create_engine(TestSQLData.connection_string) self.assertRaises(NotImplementedError, sql_copy.create_table, eng) def _check_entries(self, engine): with engine.begin() as conn: meta = sqlalchemy.MetaData() meta.reflect(bind=engine) self.assertEqual({"table_updates", "item_property"}, set(meta.tables.keys())) table = meta.tables[self.SQLATask.table] s = sqlalchemy.select([sqlalchemy.func.count(table.c.item)]) result = conn.execute(s).fetchone() self.assertEqual(len(BaseTask.TASK_LIST), result[0]) s = sqlalchemy.select([table]).order_by(table.c.item) result = conn.execute(s).fetchall() for i in range(len(BaseTask.TASK_LIST)): given = BaseTask.TASK_LIST[i].strip("\n").split("\t") given = (str(given[0]), str(given[1])) self.assertEqual(given, tuple(result[i])) def test_rows(self): task, task0 = self.SQLATask(), BaseTask() luigi.build([task, task0], local_scheduler=True, workers=self.NUM_WORKERS) for i, row in enumerate(task.rows()): given = BaseTask.TASK_LIST[i].strip("\n").split("\t") self.assertEqual(row, given) def test_run(self): """ Checking that the runs go as expected. Rerunning the same shouldn't end up inserting more rows into the db. 
:return: """ task, task0 = self.SQLATask(), BaseTask() self.engine = sqlalchemy.create_engine(task.connection_string) luigi.build([task0, task], local_scheduler=True) self._check_entries(self.engine) # rerun and the num entries should be the same luigi.build([task0, task], local_scheduler=True, workers=self.NUM_WORKERS) self._check_entries(self.engine) def test_run_with_chunk_size(self): """ The chunk_size can be specified in order to control the batch size for inserts. :return: """ task, task0 = self.SQLATask(), BaseTask() self.engine = sqlalchemy.create_engine(task.connection_string) task.chunk_size = 2 # change chunk size and check it runs ok luigi.build([task, task0], local_scheduler=True, workers=self.NUM_WORKERS) self._check_entries(self.engine) def test_reflect(self): """ If the table is setup already, then one can set reflect to True, and completely skip the columns part. It is not even required at that point. :return: """ SQLATask = self.SQLATask class AnotherSQLATask(sqla.CopyToTable): connection_string = self.connection_string table = "item_property" reflect = True chunk_size = 1 def requires(self): return SQLATask() def copy(self, conn, ins_rows, table_bound): ins = ( table_bound.update() .where(table_bound.c.property == sqlalchemy.bindparam("_property")) .values({table_bound.c.item: sqlalchemy.bindparam("_item")}) ) conn.execute(ins, ins_rows) def rows(self): for line in BaseTask.TASK_LIST: yield line.strip("\n").split("\t") task0, task1, task2 = AnotherSQLATask(), self.SQLATask(), BaseTask() luigi.build([task0, task1, task2], local_scheduler=True, workers=self.NUM_WORKERS) self._check_entries(self.engine) def test_create_marker_table(self): """ Is the marker table created as expected for the SQLAlchemyTarget :return: """ target = sqla.SQLAlchemyTarget(self.connection_string, "test_table", "12312123") target.create_marker_table() self.assertTrue(target.engine.dialect.has_table(target.engine.connect(), target.marker_table)) def test_touch(self): """ 
Touch takes care of creating a checkpoint for task completion :return: """ target = sqla.SQLAlchemyTarget(self.connection_string, "test_table", "12312123") target.create_marker_table() self.assertFalse(target.exists()) target.touch() self.assertTrue(target.exists()) def test_row_overload(self): """Overload the rows method and we should be able to insert data into database""" class SQLARowOverloadTest(sqla.CopyToTable): columns = [(["item", sqlalchemy.String(64)], {}), (["property", sqlalchemy.String(64)], {})] connection_string = self.connection_string table = "item_property" chunk_size = 1 def rows(self): tasks = [ ("item0", "property0"), ("item1", "property1"), ("item2", "property2"), ("item3", "property3"), ("item4", "property4"), ("item5", "property5"), ("item6", "property6"), ("item7", "property7"), ("item8", "property8"), ("item9", "property9"), ] for row in tasks: yield row task = SQLARowOverloadTest() luigi.build([task], local_scheduler=True, workers=self.NUM_WORKERS) self._check_entries(self.engine) def test_column_row_separator(self): """ Test alternate column row separator works :return: """ class ModBaseTask(luigi.Task): def output(self): return MockTarget("ModBaseTask", mirror_on_stderr=True) def run(self): out = self.output().open("w") tasks = ["item%d,property%d\n" % (i, i) for i in range(10)] for task in tasks: out.write(task) out.close() class ModSQLATask(sqla.CopyToTable): columns = [(["item", sqlalchemy.String(64)], {}), (["property", sqlalchemy.String(64)], {})] connection_string = self.connection_string table = "item_property" column_separator = "," chunk_size = 1 def requires(self): return ModBaseTask() task1, task2 = ModBaseTask(), ModSQLATask() luigi.build([task1, task2], local_scheduler=True, workers=self.NUM_WORKERS) self._check_entries(self.engine) def test_update_rows_test(self): """ Overload the copy() method and implement an update action. 
:return: """ class ModBaseTask(luigi.Task): def output(self): return MockTarget("BaseTask", mirror_on_stderr=True) def run(self): out = self.output().open("w") for task in self.TASK_LIST: out.write("dummy_" + task) out.close() class ModSQLATask(sqla.CopyToTable): connection_string = self.connection_string table = "item_property" columns = [(["item", sqlalchemy.String(64)], {}), (["property", sqlalchemy.String(64)], {})] chunk_size = 1 def requires(self): return ModBaseTask() class UpdateSQLATask(sqla.CopyToTable): connection_string = self.connection_string table = "item_property" reflect = True chunk_size = 1 def requires(self): return ModSQLATask() def copy(self, conn, ins_rows, table_bound): ins = ( table_bound.update() .where(table_bound.c.property == sqlalchemy.bindparam("_property")) .values({table_bound.c.item: sqlalchemy.bindparam("_item")}) ) conn.execute(ins, ins_rows) def rows(self): for task in self.TASK_LIST: yield task.strip("\n").split("\t") # Running only task1, and task2 should fail task1, task2, task3 = ModBaseTask(), ModSQLATask(), UpdateSQLATask() luigi.build([task1, task2, task3], local_scheduler=True, workers=self.NUM_WORKERS) self._check_entries(self.engine) @skipOnTravisAndGithubActions("AssertionError: 10 != 7; https://travis-ci.org/spotify/luigi/jobs/156732446") def test_multiple_tasks(self): """ Test a case where there are multiple tasks :return: """ class SmallSQLATask(sqla.CopyToTable): item = luigi.Parameter() property = luigi.Parameter() columns = [(["item", sqlalchemy.String(64)], {}), (["property", sqlalchemy.String(64)], {})] connection_string = self.connection_string table = "item_property" chunk_size = 1 def rows(self): yield (self.item, self.property) class ManyBaseTask(luigi.Task): def requires(self): for t in BaseTask.TASK_LIST: item, property = t.strip().split("\t") yield SmallSQLATask(item=item, property=property) task2 = ManyBaseTask() luigi.build([task2], local_scheduler=True, workers=self.NUM_WORKERS) 
self._check_entries(self.engine) def test_multiple_engines(self): """ Test case where different tasks require different SQL engines. """ alt_db = self.get_connection_string("sqlatest2.db") class MultiEngineTask(self.SQLATask): connection_string = alt_db task0, task1, task2 = BaseTask(), self.SQLATask(), MultiEngineTask() self.assertTrue(task1.output().engine != task2.output().engine) luigi.build([task2, task1, task0], local_scheduler=True, workers=self.NUM_WORKERS) self._check_entries(task1.output().engine) self._check_entries(task2.output().engine) @pytest.mark.contrib class TestSQLA2(TestSQLA): """2 workers version""" NUM_WORKERS = 2 ================================================ FILE: test/contrib/streaming_test.py ================================================ import os import unittest import mock import pytest from luigi import Parameter from luigi.contrib import mrrunner from luigi.contrib.hadoop import HadoopJobRunner, JobTask from luigi.contrib.hdfs import HdfsTarget class MockStreamingJob(JobTask): package_binary = Parameter(default=None) def output(self): rv = mock.MagicMock(HdfsTarget) rv.path = "test_path" return rv class MockStreamingJobWithExtraArguments(JobTask): package_binary = Parameter(default=None) def extra_streaming_arguments(self): return [("myargument", "/path/to/coolvalue")] def extra_archives(self): return ["/path/to/myarchive.zip", "/path/to/other_archive.zip"] def output(self): rv = mock.MagicMock(HdfsTarget) rv.path = "test_path" return rv @pytest.mark.apache class StreamingRunTest(unittest.TestCase): @mock.patch("luigi.contrib.hadoop.shutil") @mock.patch("luigi.contrib.hadoop.run_and_track_hadoop_job") def test_package_binary_run(self, rath_job, shutil): job_runner = HadoopJobRunner("jar_path", end_job_with_atomic_move_dir=False) job_runner.run_job(MockStreamingJob(package_binary="test_bin.pex")) self.assertEqual(1, shutil.copy.call_count) pex_src, pex_dest = shutil.copy.call_args[0] runner_fname = os.path.basename(pex_dest) 
self.assertEqual("test_bin.pex", pex_src) self.assertEqual("mrrunner.pex", runner_fname) self.assertEqual(1, rath_job.call_count) mr_args = rath_job.call_args[0][0] mr_args_pairs = zip(mr_args, mr_args[1:]) self.assertIn(("-mapper", "python mrrunner.pex map"), mr_args_pairs) self.assertIn(("-file", pex_dest), mr_args_pairs) @mock.patch("luigi.contrib.hadoop.create_packages_archive") @mock.patch("luigi.contrib.hadoop.run_and_track_hadoop_job") def test_standard_run(self, rath_job, cpa): job_runner = HadoopJobRunner("jar_path", end_job_with_atomic_move_dir=False) job_runner.run_job(MockStreamingJob()) self.assertEqual(1, cpa.call_count) self.assertEqual(1, rath_job.call_count) mr_args = rath_job.call_args[0][0] mr_args_pairs = zip(mr_args, mr_args[1:]) self.assertIn(("-mapper", "python mrrunner.py map"), mr_args_pairs) self.assertIn(("-file", mrrunner.__file__.rstrip("c")), mr_args_pairs) @mock.patch("luigi.contrib.hadoop.create_packages_archive") @mock.patch("luigi.contrib.hadoop.run_and_track_hadoop_job") def test_run_with_extra_arguments(self, rath_job, cpa): job_runner = HadoopJobRunner("jar_path", end_job_with_atomic_move_dir=False) job_runner.run_job(MockStreamingJobWithExtraArguments()) self.assertEqual(1, cpa.call_count) self.assertEqual(1, rath_job.call_count) mr_args = rath_job.call_args[0][0] mr_args_pairs = list(zip(mr_args, mr_args[1:])) self.assertIn(("-myargument", "/path/to/coolvalue"), mr_args_pairs) self.assertIn(("-archives", "/path/to/myarchive.zip,/path/to/other_archive.zip"), mr_args_pairs) ================================================ FILE: test/contrib/test_ssh.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Integration tests for ssh module. """ import os import random import socket import subprocess import pytest import target_test from helpers import unittest from luigi.contrib.ssh import RemoteCalledProcessError, RemoteContext, RemoteFileSystem, RemoteTarget from luigi.target import FileAlreadyExists, MissingParentDirectory working_ssh_host = os.environ.get("SSH_TEST_HOST", "localhost") # set this to a working ssh host string (e.g. "localhost") to activate integration tests # The following tests require a working ssh server at `working_ssh_host` # the test runner can ssh into using password-less authentication # since `nc` has different syntax on different platforms # we use a short python command to start # a 'hello'-server on the remote machine HELLO_SERVER_CMD = """ import socket, sys listener = socket.socket() listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) listener.bind(('localhost', 2134)) listener.listen(1) sys.stdout.write('ready') sys.stdout.flush() conn = listener.accept()[0] conn.sendall(b'hello') """ try: x = subprocess.check_output("ssh %s -S none -o BatchMode=yes 'echo 1'" % working_ssh_host, shell=True) if x != b"1\n": raise unittest.SkipTest("Not able to connect to ssh server") except Exception: raise unittest.SkipTest("Not able to connect to ssh server") @pytest.mark.contrib class TestRemoteContext(unittest.TestCase): def setUp(self): self.context = RemoteContext(working_ssh_host) def tearDown(self): try: self.remote_server_handle.terminate() except Exception: pass def test_check_output(self): """Test check_output ssh Assumes the 
running user can ssh to working_ssh_host """ output = self.context.check_output(["echo", "-n", "luigi"]) self.assertEqual(output, b"luigi") def test_tunnel(self): print("Setting up remote listener...") self.remote_server_handle = self.context.Popen(["python", "-c", '"{0}"'.format(HELLO_SERVER_CMD)], stdout=subprocess.PIPE) print("Setting up tunnel") with self.context.tunnel(2135, 2134): print("Tunnel up!") # hack to make sure the listener process is up # and running before we write to it server_output = self.remote_server_handle.stdout.read(5) self.assertEqual(server_output, b"ready") print("Connecting to server via tunnel") s = socket.socket() s.connect(("localhost", 2135)) print( "Receiving...", ) response = s.recv(5) self.assertEqual(response, b"hello") print("Closing connection") s.close() print("Waiting for listener...") output, _ = self.remote_server_handle.communicate() self.assertEqual(self.remote_server_handle.returncode, 0) print("Closing tunnel") @pytest.mark.contrib class TestRemoteTarget(unittest.TestCase): """These tests assume RemoteContext working in order for setUp and tearDown to work """ def setUp(self): self.ctx = RemoteContext(working_ssh_host) self.filepath = "/tmp/luigi_remote_test.dat" self.target = RemoteTarget( self.filepath, working_ssh_host, ) self.ctx.check_output(["rm", "-rf", self.filepath]) self.ctx.check_output(["echo -n 'hello' >", self.filepath]) def tearDown(self): self.ctx.check_output(["rm", "-rf", self.filepath]) def test_exists(self): self.assertTrue(self.target.exists()) no_file = RemoteTarget( "/tmp/_file_that_doesnt_exist_", working_ssh_host, ) self.assertFalse(no_file.exists()) def test_remove(self): self.target.remove() self.assertRaises(subprocess.CalledProcessError, self.ctx.check_output, ["cat", self.filepath]) def test_open(self): f = self.target.open("r") file_content = f.read() f.close() self.assertEqual(file_content, "hello") self.assertTrue(self.target.fs.exists(self.filepath)) 
self.assertFalse(self.target.fs.isdir(self.filepath)) def test_context_manager(self): with self.target.open("r") as f: file_content = f.read() self.assertEqual(file_content, "hello") @pytest.mark.contrib class TestRemoteFilesystem(unittest.TestCase): def setUp(self): self.fs = RemoteFileSystem(working_ssh_host) self.root = "/tmp/luigi-remote-test" self.directory = self.root + "/dir" self.filepath = self.directory + "/file" self.target = RemoteTarget( self.filepath, working_ssh_host, ) self.fs.remote_context.check_output(["rm", "-rf", self.root]) self.addCleanup(self.fs.remote_context.check_output, ["rm", "-rf", self.root]) def test_mkdir(self): self.assertFalse(self.fs.isdir(self.directory)) self.assertRaises(MissingParentDirectory, self.fs.mkdir, self.directory, parents=False) self.fs.mkdir(self.directory) self.assertTrue(self.fs.isdir(self.directory)) # Shouldn't throw self.fs.mkdir(self.directory) self.assertRaises(FileAlreadyExists, self.fs.mkdir, self.directory, raise_if_exists=True) def test_list(self): with self.target.open("w"): pass self.assertEqual([self.target.path], list(self.fs.listdir(self.directory))) @pytest.mark.contrib class TestGetAttrRecursion(unittest.TestCase): def test_recursion_on_delete(self): target = RemoteTarget("/etc/this/does/not/exist", working_ssh_host) with self.assertRaises(RemoteCalledProcessError): with target.open("w") as fh: fh.write("test") @pytest.mark.contrib class TestRemoteTargetAtomicity(unittest.TestCase, target_test.FileSystemTargetTestMixin): path = "/tmp/luigi_remote_atomic_test.txt" ctx = RemoteContext(working_ssh_host) def create_target(self, format=None): return RemoteTarget(self.path, working_ssh_host, format=format) def _exists(self, path): try: self.ctx.check_output(["test", "-e", path]) except subprocess.CalledProcessError as e: if e.returncode == 1: return False else: raise return True def assertCleanUp(self, tp): self.assertFalse(self._exists(tp)) def setUp(self): self.ctx.check_output(["rm", "-rf", 
self.path]) self.local_file = "/tmp/local_luigi_remote_atomic_test.txt" if os.path.exists(self.local_file): os.remove(self.local_file) def tearDown(self): self.ctx.check_output(["rm", "-rf", self.path]) if os.path.exists(self.local_file): os.remove(self.local_file) def test_put(self): f = open(self.local_file, "w") f.write("hello") f.close() t = RemoteTarget(self.path, working_ssh_host) t.put(self.local_file) self.assertTrue(self._exists(self.path)) def test_get(self): self.ctx.check_output(["echo -n 'hello' >", self.path]) t = RemoteTarget(self.path, working_ssh_host) t.get(self.local_file) f = open(self.local_file, "r") file_content = f.read() self.assertEqual(file_content, "hello") test_move_on_fs = None # ssh don't have move (yet?) test_rename_dont_move_on_fs = None # ssh don't have move (yet?) class TestRemoteTargetCreateDirectories(TestRemoteTargetAtomicity): path = "/tmp/%s/xyz/luigi_remote_atomic_test.txt" % random.randint(0, 999999999) class TestRemoteTargetRelative(TestRemoteTargetAtomicity): path = "luigi_remote_atomic_test.txt" ================================================ FILE: test/create_packages_archive_root/module.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ================================================ FILE: test/create_packages_archive_root/package/__init__.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ================================================ FILE: test/create_packages_archive_root/package/submodule.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import os # NOQA ================================================ FILE: test/create_packages_archive_root/package/submodule_with_absolute_import.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import os # NOQA ================================================ FILE: test/create_packages_archive_root/package/submodule_without_imports.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ================================================ FILE: test/create_packages_archive_root/package/subpackage/__init__.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ================================================ FILE: test/create_packages_archive_root/package/subpackage/submodule.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import os # NOQA ================================================ FILE: test/custom_metrics_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2017 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import os import tempfile import time from helpers import LuigiTestCase, temporary_unloaded_module import luigi from luigi.metrics import MetricsCollectors from luigi.scheduler import Scheduler from luigi.worker import Worker class CustomMetricsTestMyTask(luigi.Task): root_path = luigi.PathParameter() n = luigi.IntParameter() def output(self): basename = "%s_%s.txt" % (self.__class__.__name__, self.n) return luigi.LocalTarget(os.path.join(self.root_path, basename)) def run(self): time.sleep(self.n) with self.output().open("w") as f: f.write("content\n") class CustomMetricsTestWrapper(CustomMetricsTestMyTask): def requires(self): return [self.clone(CustomMetricsTestMyTask, n=n) for n in range(self.n)] METRICS_COLLECTOR_MODULE = b""" from luigi.metrics import NoMetricsCollector class CustomMetricsCollector(NoMetricsCollector): def __init__(self, *args, **kwargs): super(CustomMetricsCollector, self).__init__(*args, **kwargs) self.elapsed = {} def handle_task_statistics(self, task, statistics): if "elapsed" in statistics: self.elapsed[(task.family, task.params.get("n"))] = statistics["elapsed"] """ TASK_CONTEXT_MODULE = b""" import time class CustomTaskContext: def __init__(self, task_process): self._task_process = task_process self._start = None def __enter__(self): self._start = time.perf_counter() return self def __exit__(self, exc_type, exc_val, exc_tb): assert self._start is not None elapsed = time.perf_counter() - self._start self._task_process.status_reporter.report_task_statistics({"elapsed": elapsed}) """ class CustomMetricsTest(LuigiTestCase): """ Test showcasing collection of cutom metrics """ def _run_task_on_worker(self, worker): with tempfile.TemporaryDirectory() as tmpdir: task = CustomMetricsTestWrapper(n=3, root_path=tmpdir) self.assertTrue(worker.add(task)) worker.run() self.assertTrue(task.complete()) def _create_worker_and_run_task(self, scheduler): with temporary_unloaded_module(TASK_CONTEXT_MODULE) as task_context_module: with 
Worker(scheduler=scheduler, worker_id="X", task_process_context=task_context_module + ".CustomTaskContext") as worker: self._run_task_on_worker(worker) def test_custom_metrics(self): with temporary_unloaded_module(METRICS_COLLECTOR_MODULE) as metrics_collector_module: scheduler = Scheduler(metrics_collector=MetricsCollectors.custom, metrics_custom_import=metrics_collector_module + ".CustomMetricsCollector") self._create_worker_and_run_task(scheduler) for (family, n), elapsed in scheduler._state._metrics_collector.elapsed.items(): self.assertTrue(family in {"CustomMetricsTestMyTask", "CustomMetricsTestWrapper"}) self.assertTrue(elapsed >= float(n)) ================================================ FILE: test/customized_run_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import logging import time from helpers import unittest import luigi import luigi.contrib.hadoop import luigi.rpc import luigi.scheduler import luigi.worker class DummyTask(luigi.Task): task_namespace = "customized_run" # to prevent task name coflict between tests n = luigi.Parameter() def __init__(self, *args, **kwargs): super(DummyTask, self).__init__(*args, **kwargs) self.has_run = False def complete(self): return self.has_run def run(self): logging.debug("%s - setting has_run", self) self.has_run = True class CustomizedLocalScheduler(luigi.scheduler.Scheduler): def __init__(self, *args, **kwargs): super(CustomizedLocalScheduler, self).__init__(*args, **kwargs) self.has_run = False def get_work(self, worker, host=None, **kwargs): r = super(CustomizedLocalScheduler, self).get_work(worker=worker, host=host) self.has_run = True return r def complete(self): return self.has_run class CustomizedRemoteScheduler(luigi.rpc.RemoteScheduler): def __init__(self, *args, **kwargs): super(CustomizedRemoteScheduler, self).__init__(*args, **kwargs) self.has_run = False def get_work(self, worker, host=None): r = super(CustomizedRemoteScheduler, self).get_work(worker=worker, host=host) self.has_run = True return r def complete(self): return self.has_run class CustomizedWorker(luigi.worker.Worker): def __init__(self, *args, **kwargs): super(CustomizedWorker, self).__init__(*args, **kwargs) self.has_run = False def _run_task(self, task_id): super(CustomizedWorker, self)._run_task(task_id) self.has_run = True def complete(self): return self.has_run class CustomizedWorkerSchedulerFactory: def __init__(self, *args, **kwargs): self.scheduler = CustomizedLocalScheduler() self.worker = CustomizedWorker(self.scheduler) def create_local_scheduler(self): return self.scheduler def create_remote_scheduler(self, url): return CustomizedRemoteScheduler(url) def create_worker(self, scheduler, worker_processes=None, assistant=False): return self.worker class 
CustomizedWorkerTest(unittest.TestCase): """Test that luigi's build method (and ultimately the run method) can accept a customized worker and scheduler""" def setUp(self): self.worker_scheduler_factory = CustomizedWorkerSchedulerFactory() self.time = time.time def tearDown(self): if time.time != self.time: time.time = self.time def setTime(self, t): time.time = lambda: t def test_customized_worker(self): a = DummyTask(3) self.assertFalse(a.complete()) self.assertFalse(self.worker_scheduler_factory.worker.complete()) luigi.build([a], worker_scheduler_factory=self.worker_scheduler_factory) self.assertTrue(a.complete()) self.assertTrue(self.worker_scheduler_factory.worker.complete()) def test_cmdline_custom_worker(self): self.assertFalse(self.worker_scheduler_factory.worker.complete()) luigi.run(["customized_run.DummyTask", "--n", "4"], worker_scheduler_factory=self.worker_scheduler_factory) self.assertTrue(self.worker_scheduler_factory.worker.complete()) ================================================ FILE: test/date_interval_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import datetime from helpers import LuigiTestCase, in_parse import luigi from luigi.parameter import DateIntervalParameter as DI class DateIntervalTest(LuigiTestCase): def test_date(self): di = DI().parse("2012-01-01") self.assertEqual(di.dates(), [datetime.date(2012, 1, 1)]) self.assertEqual(di.next().dates(), [datetime.date(2012, 1, 2)]) self.assertEqual(di.prev().dates(), [datetime.date(2011, 12, 31)]) self.assertEqual(str(di), "2012-01-01") def test_month(self): di = DI().parse("2012-01") self.assertEqual(di.dates(), [datetime.date(2012, 1, 1) + datetime.timedelta(i) for i in range(31)]) self.assertEqual(di.next().dates(), [datetime.date(2012, 2, 1) + datetime.timedelta(i) for i in range(29)]) self.assertEqual(di.prev().dates(), [datetime.date(2011, 12, 1) + datetime.timedelta(i) for i in range(31)]) self.assertEqual(str(di), "2012-01") def test_year(self): di = DI().parse("2012") self.assertEqual(di.dates(), [datetime.date(2012, 1, 1) + datetime.timedelta(i) for i in range(366)]) self.assertEqual(di.next().dates(), [datetime.date(2013, 1, 1) + datetime.timedelta(i) for i in range(365)]) self.assertEqual(di.prev().dates(), [datetime.date(2011, 1, 1) + datetime.timedelta(i) for i in range(365)]) self.assertEqual(str(di), "2012") def test_week(self): # >>> datetime.date(2012, 1, 1).isocalendar() # (2011, 52, 7) # >>> datetime.date(2012, 12, 31).isocalendar() # (2013, 1, 1) di = DI().parse("2011-W52") self.assertEqual(di.dates(), [datetime.date(2011, 12, 26) + datetime.timedelta(i) for i in range(7)]) self.assertEqual(di.next().dates(), [datetime.date(2012, 1, 2) + datetime.timedelta(i) for i in range(7)]) self.assertEqual(str(di), "2011-W52") di = DI().parse("2013-W01") self.assertEqual(di.dates(), [datetime.date(2012, 12, 31) + datetime.timedelta(i) for i in range(7)]) self.assertEqual(di.prev().dates(), [datetime.date(2012, 12, 24) + datetime.timedelta(i) for i in range(7)]) self.assertEqual(str(di), "2013-W01") def test_interval(self): di = 
DI().parse("2012-01-01-2012-02-01") self.assertEqual(di.dates(), [datetime.date(2012, 1, 1) + datetime.timedelta(i) for i in range(31)]) self.assertRaises(NotImplementedError, di.next) self.assertRaises(NotImplementedError, di.prev) self.assertEqual(di.to_string(), "2012-01-01-2012-02-01") def test_exception(self): self.assertRaises(ValueError, DI().parse, "xyz") def test_comparison(self): a = DI().parse("2011") b = DI().parse("2013") c = DI().parse("2012") self.assertTrue(a < b) self.assertTrue(a < c) self.assertTrue(b > c) d = DI().parse("2012") self.assertTrue(d == c) self.assertEqual(d, min(c, b)) self.assertEqual(3, len({a, b, c, d})) def test_comparison_different_types(self): x = DI().parse("2012") y = DI().parse("2012-01-01-2013-01-01") self.assertRaises(TypeError, lambda: x == y) def test_parameter_parse_and_default(self): month = luigi.date_interval.Month(2012, 11) other = luigi.date_interval.Month(2012, 10) class MyTask(luigi.Task): di = DI(default=month) class MyTaskNoDefault(luigi.Task): di = DI() self.assertEqual(MyTask().di, month) in_parse(["MyTask", "--di", "2012-10"], lambda task: self.assertEqual(task.di, other)) task = MyTask(month) self.assertEqual(task.di, month) task = MyTask(di=month) self.assertEqual(task.di, month) task = MyTask(other) self.assertNotEqual(task.di, month) def fail1(): return MyTaskNoDefault() self.assertRaises(luigi.parameter.MissingParameterException, fail1) in_parse(["MyTaskNoDefault", "--di", "2012-10"], lambda task: self.assertEqual(task.di, other)) def test_hours(self): d = DI().parse("2015") self.assertEqual(len(list(d.hours())), 24 * 365) def test_cmp(self): operators = [lambda x, y: x == y, lambda x, y: x != y, lambda x, y: x < y, lambda x, y: x > y, lambda x, y: x <= y, lambda x, y: x >= y] dates = [ (1, 30, DI().parse("2015-01-01-2015-01-30")), (1, 15, DI().parse("2015-01-01-2015-01-15")), (10, 20, DI().parse("2015-01-10-2015-01-20")), (20, 30, DI().parse("2015-01-20-2015-01-30")), ] for from_a, to_a, di_a in 
dates: for from_b, to_b, di_b in dates: for op in operators: self.assertEqual(op((from_a, to_a), (from_b, to_b)), op(di_a, di_b)) ================================================ FILE: test/date_parameter_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import datetime from helpers import in_parse, unittest import luigi import luigi.interface class DateTask(luigi.Task): day = luigi.DateParameter() class DateHourTask(luigi.Task): dh = luigi.DateHourParameter() class DateMinuteTask(luigi.Task): dm = luigi.DateMinuteParameter() class DateSecondTask(luigi.Task): ds = luigi.DateSecondParameter() class MonthTask(luigi.Task): month = luigi.MonthParameter() class YearTask(luigi.Task): year = luigi.YearParameter() class DateParameterTest(unittest.TestCase): def test_parse(self): d = luigi.DateParameter().parse("2015-04-03") self.assertEqual(d, datetime.date(2015, 4, 3)) def test_serialize(self): d = luigi.DateParameter().serialize(datetime.date(2015, 4, 3)) self.assertEqual(d, "2015-04-03") def test_parse_interface(self): in_parse(["DateTask", "--day", "2015-04-03"], lambda task: self.assertEqual(task.day, datetime.date(2015, 4, 3))) def test_serialize_task(self): t = DateTask(datetime.date(2015, 4, 3)) self.assertEqual(str(t), "DateTask(day=2015-04-03)") class DateHourParameterTest(unittest.TestCase): def test_parse(self): dh = 
luigi.DateHourParameter().parse("2013-02-01T18") self.assertEqual(dh, datetime.datetime(2013, 2, 1, 18, 0, 0)) def test_date_to_dh(self): date = luigi.DateHourParameter().normalize(datetime.date(2000, 1, 1)) self.assertEqual(date, datetime.datetime(2000, 1, 1, 0)) def test_serialize(self): dh = luigi.DateHourParameter().serialize(datetime.datetime(2013, 2, 1, 18, 0, 0)) self.assertEqual(dh, "2013-02-01T18") def test_parse_interface(self): in_parse(["DateHourTask", "--dh", "2013-02-01T18"], lambda task: self.assertEqual(task.dh, datetime.datetime(2013, 2, 1, 18, 0, 0))) def test_serialize_task(self): t = DateHourTask(datetime.datetime(2013, 2, 1, 18, 0, 0)) self.assertEqual(str(t), "DateHourTask(dh=2013-02-01T18)") class DateMinuteParameterTest(unittest.TestCase): def test_parse(self): dm = luigi.DateMinuteParameter().parse("2013-02-01T1842") self.assertEqual(dm, datetime.datetime(2013, 2, 1, 18, 42, 0)) def test_parse_padding_zero(self): dm = luigi.DateMinuteParameter().parse("2013-02-01T1807") self.assertEqual(dm, datetime.datetime(2013, 2, 1, 18, 7, 0)) def test_parse_deprecated(self): with self.assertWarnsRegex(DeprecationWarning, 'Using "H" between hours and minutes is deprecated, omit it instead.'): dm = luigi.DateMinuteParameter().parse("2013-02-01T18H42") self.assertEqual(dm, datetime.datetime(2013, 2, 1, 18, 42, 0)) def test_serialize(self): dm = luigi.DateMinuteParameter().serialize(datetime.datetime(2013, 2, 1, 18, 42, 0)) self.assertEqual(dm, "2013-02-01T1842") def test_serialize_padding_zero(self): dm = luigi.DateMinuteParameter().serialize(datetime.datetime(2013, 2, 1, 18, 7, 0)) self.assertEqual(dm, "2013-02-01T1807") def test_parse_interface(self): in_parse(["DateMinuteTask", "--dm", "2013-02-01T1842"], lambda task: self.assertEqual(task.dm, datetime.datetime(2013, 2, 1, 18, 42, 0))) def test_serialize_task(self): t = DateMinuteTask(datetime.datetime(2013, 2, 1, 18, 42, 0)) self.assertEqual(str(t), "DateMinuteTask(dm=2013-02-01T1842)") class 
DateSecondParameterTest(unittest.TestCase): def test_parse(self): ds = luigi.DateSecondParameter().parse("2013-02-01T184227") self.assertEqual(ds, datetime.datetime(2013, 2, 1, 18, 42, 27)) def test_serialize(self): ds = luigi.DateSecondParameter().serialize(datetime.datetime(2013, 2, 1, 18, 42, 27)) self.assertEqual(ds, "2013-02-01T184227") def test_parse_interface(self): in_parse(["DateSecondTask", "--ds", "2013-02-01T184227"], lambda task: self.assertEqual(task.ds, datetime.datetime(2013, 2, 1, 18, 42, 27))) def test_serialize_task(self): t = DateSecondTask(datetime.datetime(2013, 2, 1, 18, 42, 27)) self.assertEqual(str(t), "DateSecondTask(ds=2013-02-01T184227)") class MonthParameterTest(unittest.TestCase): def test_parse(self): m = luigi.MonthParameter().parse("2015-04") self.assertEqual(m, datetime.date(2015, 4, 1)) def test_construct_month_interval(self): m = MonthTask(luigi.date_interval.Month(2015, 4)) self.assertEqual(m.month, datetime.date(2015, 4, 1)) def test_month_interval_default(self): class MonthDefaultTask(luigi.task.Task): month = luigi.MonthParameter(default=luigi.date_interval.Month(2015, 4)) m = MonthDefaultTask() self.assertEqual(m.month, datetime.date(2015, 4, 1)) def test_serialize(self): m = luigi.MonthParameter().serialize(datetime.date(2015, 4, 3)) self.assertEqual(m, "2015-04") def test_parse_interface(self): in_parse(["MonthTask", "--month", "2015-04"], lambda task: self.assertEqual(task.month, datetime.date(2015, 4, 1))) def test_serialize_task(self): task = MonthTask(datetime.date(2015, 4, 3)) self.assertEqual(str(task), "MonthTask(month=2015-04)") class YearParameterTest(unittest.TestCase): def test_parse(self): year = luigi.YearParameter().parse("2015") self.assertEqual(year, datetime.date(2015, 1, 1)) def test_construct_year_interval(self): y = YearTask(luigi.date_interval.Year(2015)) self.assertEqual(y.year, datetime.date(2015, 1, 1)) def test_year_interval_default(self): class YearDefaultTask(luigi.task.Task): year = 
luigi.YearParameter(default=luigi.date_interval.Year(2015)) m = YearDefaultTask() self.assertEqual(m.year, datetime.date(2015, 1, 1)) def test_serialize(self): year = luigi.YearParameter().serialize(datetime.date(2015, 4, 3)) self.assertEqual(year, "2015") def test_parse_interface(self): in_parse(["YearTask", "--year", "2015"], lambda task: self.assertEqual(task.year, datetime.date(2015, 1, 1))) def test_serialize_task(self): task = YearTask(datetime.date(2015, 4, 3)) self.assertEqual(str(task), "YearTask(year=2015)") ================================================ FILE: test/db_task_history_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from helpers import unittest, with_config import luigi import luigi.scheduler from luigi.db_task_history import DbTaskHistory from luigi.parameter import ParameterVisibility from luigi.task_status import DONE, PENDING, RUNNING class DummyTask(luigi.Task): foo = luigi.Parameter(default="foo") class ParamTask(luigi.Task): param1 = luigi.Parameter() param2 = luigi.IntParameter(visibility=ParameterVisibility.HIDDEN) param3 = luigi.Parameter(default="empty", visibility=ParameterVisibility.PRIVATE) class DbTaskHistoryTest(unittest.TestCase): @with_config(dict(task_history=dict(db_connection="sqlite:///:memory:"))) def setUp(self): self.history = DbTaskHistory() def test_task_list(self): self.run_task(DummyTask()) self.run_task(DummyTask(foo="bar")) with self.history._session() as session: tasks = list(self.history.find_all_by_name("DummyTask", session)) self.assertEqual(len(tasks), 2) for task in tasks: self.assertEqual(task.name, "DummyTask") self.assertEqual(task.host, "hostname") def test_task_events(self): self.run_task(DummyTask()) with self.history._session() as session: tasks = list(self.history.find_all_by_name("DummyTask", session)) self.assertEqual(len(tasks), 1) [task] = tasks self.assertEqual(task.name, "DummyTask") self.assertEqual(len(task.events), 3) for event, name in zip(task.events, [DONE, RUNNING, PENDING]): self.assertEqual(event.event_name, name) def test_task_by_params(self): task1 = ParamTask("foo", "bar") task2 = ParamTask("bar", "foo") with self.history._session() as session: self.run_task(task1) self.run_task(task2) task1_record = self.history.find_all_by_parameters(task_name="ParamTask", session=session, param1="foo", param2="bar") task2_record = self.history.find_all_by_parameters(task_name="ParamTask", session=session, param1="bar", param2="foo") for task, records in zip((task1, task2), (task1_record, task2_record)): records = list(records) self.assertEqual(len(records), 1) [record] = records self.assertEqual(task.task_family, record.name) 
for param_name, param_value in task.param_kwargs.items(): self.assertTrue(param_name in record.parameters) self.assertEqual(str(param_value), record.parameters[param_name].value) def test_task_blank_param(self): self.run_task(DummyTask(foo="")) with self.history._session() as session: tasks = list(self.history.find_all_by_name("DummyTask", session)) self.assertEqual(len(tasks), 1) task_record = tasks[0] self.assertEqual(task_record.name, "DummyTask") self.assertEqual(task_record.host, "hostname") self.assertIn("foo", task_record.parameters) self.assertEqual(task_record.parameters["foo"].value, "") def run_task(self, task): task2 = luigi.scheduler.Task( task.task_id, PENDING, [], family=task.task_family, params=task.param_kwargs, retry_policy=luigi.scheduler._get_empty_retry_policy() ) self.history.task_scheduled(task2) self.history.task_started(task2, "hostname") self.history.task_finished(task2, successful=True) class MySQLDbTaskHistoryTest(unittest.TestCase): @with_config(dict(task_history=dict(db_connection="mysql+mysqlconnector://travis@localhost/luigi_test"))) def setUp(self): try: self.history = DbTaskHistory() except Exception: raise unittest.SkipTest("DBTaskHistory cannot be created: probably no MySQL available") def test_subsecond_timestamp(self): with self.history._session() as session: # Add 2 events in <1s task = DummyTask() self.run_task(task) task_record = next(self.history.find_all_by_name("DummyTask", session)) print(task_record.events) self.assertEqual(task_record.events[0].event_name, DONE) def test_utc_conversion(self): from luigi.server import from_utc with self.history._session() as session: task = DummyTask() self.run_task(task) task_record = next(self.history.find_all_by_name("DummyTask", session)) last_event = task_record.events[0] try: print(from_utc(str(last_event.ts))) except ValueError: self.fail("Failed to convert timestamp {} to UTC".format(last_event.ts)) def run_task(self, task): task2 = luigi.scheduler.Task( task.task_id, PENDING, 
[], family=task.task_family, params=task.param_kwargs, retry_policy=luigi.scheduler._get_empty_retry_policy() ) self.history.task_scheduled(task2) self.history.task_started(task2, "hostname") self.history.task_finished(task2, successful=True) ================================================ FILE: test/decorator_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import datetime import pickle from helpers import unittest import luigi import luigi.notifications from luigi.mock import MockTarget from luigi.parameter import MissingParameterException from luigi.util import common_params, copies, delegates, inherits, requires luigi.notifications.DEBUG = True class A(luigi.Task): task_namespace = "decorator" # to prevent task name conflict between tests param1 = luigi.Parameter("class A-specific default") @inherits(A) class B(luigi.Task): param2 = luigi.Parameter("class B-specific default") @inherits(B) class C(luigi.Task): param3 = luigi.Parameter("class C-specific default") @inherits(B) class D(luigi.Task): param1 = luigi.Parameter("class D overwriting class A's default") @inherits(B) class D_null(luigi.Task): param1 = None @inherits(A, B) class E(luigi.Task): param4 = luigi.Parameter("class E-specific default") @inherits(A) @inherits(B) class E_stacked(luigi.Task): param4 = luigi.Parameter("class E-specific default") class InheritTest(unittest.TestCase): def setUp(self): 
self.a = A() self.a_changed = A(param1=34) self.b = B() self.c = C() self.d = D() self.d_null = D_null() self.e = E() self.e_stacked = E_stacked() def test_has_param(self): b_params = dict(self.b.get_params()).keys() self.assertTrue("param1" in b_params) def test_default_param(self): self.assertEqual(self.b.param1, self.a.param1) def test_change_of_defaults_not_equal(self): self.assertNotEqual(self.b.param1, self.a_changed.param1) def tested_chained_inheritance(self): self.assertEqual(self.c.param2, self.b.param2) self.assertEqual(self.c.param1, self.a.param1) self.assertEqual(self.c.param1, self.b.param1) def test_overwriting_defaults(self): self.assertEqual(self.d.param2, self.b.param2) self.assertNotEqual(self.d.param1, self.b.param1) self.assertNotEqual(self.d.param1, self.a.param1) self.assertEqual(self.d.param1, "class D overwriting class A's default") def test_multiple_inheritance(self): self.assertEqual(self.e.param1, self.a.param1) self.assertEqual(self.e.param1, self.b.param1) self.assertEqual(self.e.param2, self.b.param2) def test_stacked_inheritance(self): self.assertEqual(self.e_stacked.param1, self.a.param1) self.assertEqual(self.e_stacked.param1, self.b.param1) self.assertEqual(self.e_stacked.param2, self.b.param2) def test_empty_inheritance(self): with self.assertRaises(TypeError): @inherits() class shouldfail(luigi.Task): pass def test_removing_parameter(self): self.assertFalse("param1" in dict(self.d_null.get_params()).keys()) def test_wrapper_preserve_attributes(self): self.assertEqual(B.__name__, "B") class F(luigi.Task): param1 = luigi.Parameter("A parameter on a base task, that will be required later.") @inherits(F) class G(luigi.Task): param2 = luigi.Parameter("A separate parameter that doesn't affect 'F'") def requires(self): return F(**common_params(self, F)) @inherits(G) class H(luigi.Task): param2 = luigi.Parameter("OVERWRITING") def requires(self): return G(**common_params(self, G)) @inherits(G) class H_null(luigi.Task): param2 = None 
def requires(self): special_param2 = str(datetime.datetime.now()) return G(param2=special_param2, **common_params(self, G)) @inherits(G) class I_task(luigi.Task): def requires(self): return F(**common_params(self, F)) class J(luigi.Task): param1 = luigi.Parameter() # something required, with no default @inherits(J) class K_shouldnotinstantiate(luigi.Task): param2 = luigi.Parameter("A K-specific parameter") @inherits(J) class K_shouldfail(luigi.Task): param1 = None param2 = luigi.Parameter("A K-specific parameter") def requires(self): return J(**common_params(self, J)) @inherits(J) class K_shouldsucceed(luigi.Task): param1 = None param2 = luigi.Parameter("A K-specific parameter") def requires(self): return J(param1="Required parameter", **common_params(self, J)) @inherits(J) class K_wrongparamsorder(luigi.Task): param1 = None param2 = luigi.Parameter("A K-specific parameter") def requires(self): return J(param1="Required parameter", **common_params(J, self)) class RequiresTest(unittest.TestCase): def setUp(self): self.f = F() self.g = G() self.g_changed = G(param1="changing the default") self.h = H() self.h_null = H_null() self.i = I_task() self.k_shouldfail = K_shouldfail() self.k_shouldsucceed = K_shouldsucceed() self.k_wrongparamsorder = K_wrongparamsorder() def test_inherits(self): self.assertEqual(self.f.param1, self.g.param1) self.assertEqual(self.f.param1, self.g.requires().param1) def test_change_of_defaults(self): self.assertNotEqual(self.f.param1, self.g_changed.param1) self.assertNotEqual(self.g.param1, self.g_changed.param1) self.assertNotEqual(self.f.param1, self.g_changed.requires().param1) def test_overwriting_parameter(self): self.h.requires() self.assertNotEqual(self.h.param2, self.g.param2) self.assertEqual(self.h.param2, self.h.requires().param2) self.assertEqual(self.h.param2, "OVERWRITING") def test_skipping_one_inheritance(self): self.assertEqual(self.i.requires().param1, self.f.param1) def test_removing_parameter(self): 
self.assertNotEqual(self.h_null.requires().param2, self.g.param2) def test_not_setting_required_parameter(self): self.assertRaises(MissingParameterException, self.k_shouldfail.requires) def test_setting_required_parameters(self): self.k_shouldsucceed.requires() def test_should_not_instantiate(self): self.assertRaises(MissingParameterException, K_shouldnotinstantiate) def test_resuscitation(self): k = K_shouldnotinstantiate(param1="hello") k.requires() def test_wrong_common_params_order(self): self.assertRaises(TypeError, self.k_wrongparamsorder.requires) class V(luigi.Task): n = luigi.IntParameter(default=42) @inherits(V) class W(luigi.Task): def requires(self): return self.clone_parent() @requires(V) class W2(luigi.Task): pass @requires(V) class W3(luigi.Task): n = luigi.IntParameter(default=43) class X(luigi.Task): m = luigi.IntParameter(default=56) @requires(V, X) class Y(luigi.Task): pass class CloneParentTest(unittest.TestCase): def test_clone_parent(self): w = W() v = V() self.assertEqual(w.requires(), v) self.assertEqual(w.n, 42) def test_requires(self): w2 = W2() v = V() self.assertEqual(w2.requires(), v) self.assertEqual(w2.n, 42) def test_requires_override_default(self): w3 = W3() v = V() self.assertNotEqual(w3.requires(), v) self.assertEqual(w3.n, 43) self.assertEqual(w3.requires().n, 43) def test_multiple_requires(self): y = Y() v = V() x = X() self.assertEqual(y.requires()[0], v) self.assertEqual(y.requires()[1], x) def test_empty_requires(self): with self.assertRaises(TypeError): @requires() class shouldfail(luigi.Task): pass def test_names(self): # Just make sure the decorators retain the original class names v = V() self.assertEqual(str(v), "V(n=42)") self.assertEqual(v.__class__.__name__, "V") class P(luigi.Task): date = luigi.DateParameter() def output(self): return MockTarget(self.date.strftime("/tmp/data-%Y-%m-%d.txt")) def run(self): f = self.output().open("w") print("hello, world", file=f) f.close() @copies(P) class PCopy(luigi.Task): def 
output(self): return MockTarget(self.date.strftime("/tmp/copy-data-%Y-%m-%d.txt")) class CopyTest(unittest.TestCase): def test_copy(self): luigi.build([PCopy(date=datetime.date(2012, 1, 1))], local_scheduler=True) self.assertEqual(MockTarget.fs.get_data("/tmp/data-2012-01-01.txt"), b"hello, world\n") self.assertEqual(MockTarget.fs.get_data("/tmp/copy-data-2012-01-01.txt"), b"hello, world\n") class PickleTest(unittest.TestCase): def test_pickle(self): # similar to CopyTest.test_copy p = PCopy(date=datetime.date(2013, 1, 1)) p_pickled = pickle.dumps(p) p = pickle.loads(p_pickled) luigi.build([p], local_scheduler=True) self.assertEqual(MockTarget.fs.get_data("/tmp/data-2013-01-01.txt"), b"hello, world\n") self.assertEqual(MockTarget.fs.get_data("/tmp/copy-data-2013-01-01.txt"), b"hello, world\n") class Subtask(luigi.Task): k = luigi.IntParameter() def f(self, x): return x**self.k @delegates class SubtaskDelegator(luigi.Task): def subtasks(self): return [Subtask(1), Subtask(2)] def run(self): self.s = 0 for t in self.subtasks(): self.s += t.f(42) class SubtaskTest(unittest.TestCase): def test_subtasks(self): sd = SubtaskDelegator() luigi.build([sd], local_scheduler=True) self.assertEqual(sd.s, 42 * (1 + 42)) def test_forgot_subtasks(self): def trigger_failure(): @delegates class SubtaskDelegatorBroken(luigi.Task): pass self.assertRaises(AttributeError, trigger_failure) def test_cmdline(self): # Exposes issue where wrapped tasks are registered twice under # the same name from luigi.task import Register self.assertEqual(Register.get_task_cls("SubtaskDelegator"), SubtaskDelegator) ================================================ FILE: test/dict_parameter_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import collections import json import mock import pytest from helpers import in_parse, unittest from jsonschema import Draft4Validator from jsonschema.exceptions import ValidationError import luigi import luigi.interface class DictParameterTask(luigi.Task): param = luigi.DictParameter() class DictParameterTest(unittest.TestCase): _dict = collections.OrderedDict([("username", "me"), ("password", "secret")]) def test_parse(self): d = luigi.DictParameter().parse(json.dumps(DictParameterTest._dict)) self.assertEqual(d, DictParameterTest._dict) def test_serialize(self): d = luigi.DictParameter().serialize(DictParameterTest._dict) self.assertEqual(d, '{"username": "me", "password": "secret"}') def test_parse_and_serialize(self): inputs = ['{"username": "me", "password": "secret"}', '{"password": "secret", "username": "me"}'] for json_input in inputs: _dict = luigi.DictParameter().parse(json_input) self.assertEqual(json_input, luigi.DictParameter().serialize(_dict)) def test_parse_interface(self): in_parse( ["DictParameterTask", "--param", '{"username": "me", "password": "secret"}'], lambda task: self.assertEqual(task.param, DictParameterTest._dict) ) def test_serialize_task(self): t = DictParameterTask(DictParameterTest._dict) self.assertEqual(str(t), 'DictParameterTask(param={"username": "me", "password": "secret"})') def test_parse_invalid_input(self): self.assertRaises(ValueError, lambda: luigi.DictParameter().parse('{"invalid"}')) def test_hash_normalize(self): self.assertRaises(TypeError, lambda: hash(luigi.DictParameter().parse('{"a": {"b": []}}'))) a = 
luigi.DictParameter().normalize({"a": [{"b": []}]}) b = luigi.DictParameter().normalize({"a": [{"b": []}]}) self.assertEqual(hash(a), hash(b)) def test_schema(self): a = luigi.parameter.DictParameter( schema={ "type": "object", "properties": { "an_int": {"type": "integer"}, "an_optional_str": {"type": "string"}, }, "additionalProperties": False, "required": ["an_int"], }, ) # Check that the default value is validated with pytest.raises( ValidationError, match=r"Additional properties are not allowed \('INVALID_ATTRIBUTE' was unexpected\)", ): a.normalize({"INVALID_ATTRIBUTE": 0}) # Check that empty dict is not valid with pytest.raises(ValidationError, match="'an_int' is a required property"): a.normalize({}) # Check that valid dicts work a.normalize({"an_int": 1}) a.normalize({"an_int": 1, "an_optional_str": "hello"}) # Check that invalid dicts raise correct errors with pytest.raises(ValidationError, match="'999' is not of type 'integer'"): a.normalize({"an_int": "999"}) with pytest.raises(ValidationError, match="999 is not of type 'string'"): a.normalize({"an_int": 1, "an_optional_str": 999}) # Test the example given in docstring b = luigi.DictParameter( schema={ "type": "object", "patternProperties": { ".*": {"type": "string", "enum": ["web", "staging"]}, }, } ) b.normalize({"role": "web", "env": "staging"}) with pytest.raises(ValidationError, match=r"'UNKNOWN_VALUE' is not one of \['web', 'staging'\]"): b.normalize({"role": "UNKNOWN_VALUE", "env": "staging"}) # Check that warnings are properly emitted with mock.patch("luigi.parameter._JSONSCHEMA_ENABLED", False): with pytest.warns( UserWarning, match=("The 'jsonschema' package is not installed so the parameter can not be validated even though a schema is given.") ): luigi.ListParameter(schema={"type": "object"}) # Test with a custom validator validator = Draft4Validator( schema={ "type": "object", "patternProperties": { ".*": {"type": "string", "enum": ["web", "staging"]}, }, } ) c = 
luigi.DictParameter(schema=validator) c.normalize({"role": "web", "env": "staging"}) with pytest.raises(ValidationError, match=r"'UNKNOWN_VALUE' is not one of \['web', 'staging'\]"): c.normalize({"role": "UNKNOWN_VALUE", "env": "staging"}) # Test with frozen data frozen_data = luigi.freezing.recursively_freeze({"role": "web", "env": "staging"}) c.normalize(frozen_data) ================================================ FILE: test/dynamic_import_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from helpers import LuigiTestCase, temporary_unloaded_module import luigi import luigi.interface CONTENTS = b""" import luigi class FooTask(luigi.Task): x = luigi.IntParameter() def run(self): luigi._testing_glob_var = self.x """ class CmdlineTest(LuigiTestCase): def test_dynamic_loading(self): with temporary_unloaded_module(CONTENTS) as temp_module_name: luigi.interface.run(["--module", temp_module_name, "FooTask", "--x", "123", "--local-scheduler", "--no-lock"]) self.assertEqual(luigi._testing_glob_var, 123) ================================================ FILE: test/event_callbacks_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from helpers import unittest from mock import patch import luigi from luigi import Event, Task, build from luigi.mock import MockFileSystem, MockTarget from luigi.task import flatten class DummyException(Exception): pass class EmptyTask(Task): fail = luigi.BoolParameter() def run(self): self.trigger_event(Event.PROGRESS, self, {"foo": "bar"}) if self.fail: raise DummyException() class TaskWithBrokenDependency(Task): def requires(self): raise DummyException() def run(self): pass class TaskWithCallback(Task): def run(self): print("Triggering event") self.trigger_event("foo event") class TestEventCallbacks(unittest.TestCase): def test_start_handler(self): saved_tasks = [] @EmptyTask.event_handler(Event.START) def save_task(task): print("Saving task...") saved_tasks.append(task) t = EmptyTask(True) build([t], local_scheduler=True) self.assertEqual(saved_tasks, [t]) def _run_empty_task(self, fail): progresses = [] progresses_data = [] successes = [] failures = [] exceptions = [] @EmptyTask.event_handler(Event.SUCCESS) def success(task): successes.append(task) @EmptyTask.event_handler(Event.FAILURE) def failure(task, exception): failures.append(task) exceptions.append(exception) @EmptyTask.event_handler(Event.PROGRESS) def progress(task, data): progresses.append(task) progresses_data.append(data) t = EmptyTask(fail) build([t], local_scheduler=True) return t, progresses, progresses_data, successes, failures, exceptions def test_success(self): t, progresses, progresses_data, successes, failures, exceptions = self._run_empty_task(False) self.assertEqual(progresses, [t]) 
self.assertEqual(progresses_data, [{"foo": "bar"}]) self.assertEqual(successes, [t]) self.assertEqual(failures, []) self.assertEqual(exceptions, []) def test_failure(self): t, progresses, progresses_data, successes, failures, exceptions = self._run_empty_task(True) self.assertEqual(progresses, [t]) self.assertEqual(progresses_data, [{"foo": "bar"}]) self.assertEqual(successes, []) self.assertEqual(failures, [t]) self.assertEqual(len(exceptions), 1) self.assertTrue(isinstance(exceptions[0], DummyException)) def test_broken_dependency(self): failures = [] exceptions = [] @TaskWithBrokenDependency.event_handler(Event.BROKEN_TASK) def failure(task, exception): failures.append(task) exceptions.append(exception) t = TaskWithBrokenDependency() build([t], local_scheduler=True) self.assertEqual(failures, [t]) self.assertEqual(len(exceptions), 1) self.assertTrue(isinstance(exceptions[0], DummyException)) def test_custom_handler(self): dummies = [] @TaskWithCallback.event_handler("foo event") def story_dummy(): dummies.append("foo") t = TaskWithCallback() build([t], local_scheduler=True) self.assertEqual(dummies[0], "foo") def _run_processing_time_handler(self, fail): result = [] @EmptyTask.event_handler(Event.PROCESSING_TIME) def save_task(task, processing_time): result.append((task, processing_time)) times = [43.0, 1.0] t = EmptyTask(fail) with patch("luigi.worker.time") as mock: mock.time = times.pop build([t], local_scheduler=True) return t, result def test_processing_time_handler_success(self): t, result = self._run_processing_time_handler(False) self.assertEqual(len(result), 1) task, time = result[0] self.assertTrue(task is t) self.assertEqual(time, 42.0) def test_processing_time_handler_failure(self): t, result = self._run_processing_time_handler(True) self.assertEqual(result, []) def test_remove_event_handler(self): run_cnt = 0 @EmptyTask.event_handler(luigi.Event.START) def handler(task): nonlocal run_cnt run_cnt += 1 task = EmptyTask() build([task], 
local_scheduler=True) assert run_cnt == 1 EmptyTask.remove_event_handler(luigi.Event.START, handler) build([task], local_scheduler=True) assert run_cnt == 1 # A # / \ # B(1) B(2) # | | # C(1) C(2) # | \ | \ # D(1) D(2) D(3) def eval_contents(f): with f.open("r") as i: return eval(i.read()) class ConsistentMockOutput: """ Computes output location and contents from the task and its parameters. Rids us of writing ad-hoc boilerplate output() et al. """ param = luigi.IntParameter(default=1) def output(self): return MockTarget("/%s/%u" % (self.__class__.__name__, self.param)) def produce_output(self): with self.output().open("w") as o: o.write(repr([self.task_id] + sorted([eval_contents(i) for i in flatten(self.input())]))) class HappyTestFriend(ConsistentMockOutput, luigi.Task): """ Does trivial "work", outputting the list of inputs. Results in a convenient lispy comparable. """ def run(self): self.produce_output() class D(ConsistentMockOutput, luigi.ExternalTask): pass class C(HappyTestFriend): def requires(self): return [D(self.param), D(self.param + 1)] class B(HappyTestFriend): def requires(self): return C(self.param) class A(HappyTestFriend): task_namespace = "event_callbacks" # to prevent task name coflict between tests def requires(self): return [B(1), B(2)] class TestDependencyEvents(unittest.TestCase): def tearDown(self): MockFileSystem().remove("") def _run_test(self, task, expected_events): actual_events = {} # yucky to create separate callbacks; would be nicer if the callback # received an instance of a subclass of Event, so one callback could # accumulate all types @luigi.Task.event_handler(Event.DEPENDENCY_DISCOVERED) def callback_dependency_discovered(*args): actual_events.setdefault(Event.DEPENDENCY_DISCOVERED, set()).add(tuple(map(lambda t: t.task_id, args))) @luigi.Task.event_handler(Event.DEPENDENCY_MISSING) def callback_dependency_missing(*args): actual_events.setdefault(Event.DEPENDENCY_MISSING, set()).add(tuple(map(lambda t: t.task_id, args))) 
@luigi.Task.event_handler(Event.DEPENDENCY_PRESENT) def callback_dependency_present(*args): actual_events.setdefault(Event.DEPENDENCY_PRESENT, set()).add(tuple(map(lambda t: t.task_id, args))) build([task], local_scheduler=True) self.assertEqual(actual_events, expected_events) def test_incomplete_dag(self): for param in range(1, 3): D(param).produce_output() self._run_test( A(), { "event.core.dependency.discovered": { (A(param=1).task_id, B(param=1).task_id), (A(param=1).task_id, B(param=2).task_id), (B(param=1).task_id, C(param=1).task_id), (B(param=2).task_id, C(param=2).task_id), (C(param=1).task_id, D(param=1).task_id), (C(param=1).task_id, D(param=2).task_id), (C(param=2).task_id, D(param=2).task_id), (C(param=2).task_id, D(param=3).task_id), }, "event.core.dependency.missing": { (D(param=3).task_id,), }, "event.core.dependency.present": { (D(param=1).task_id,), (D(param=2).task_id,), }, }, ) self.assertFalse(A().output().exists()) def test_complete_dag(self): for param in range(1, 4): D(param).produce_output() self._run_test( A(), { "event.core.dependency.discovered": { (A(param=1).task_id, B(param=1).task_id), (A(param=1).task_id, B(param=2).task_id), (B(param=1).task_id, C(param=1).task_id), (B(param=2).task_id, C(param=2).task_id), (C(param=1).task_id, D(param=1).task_id), (C(param=1).task_id, D(param=2).task_id), (C(param=2).task_id, D(param=2).task_id), (C(param=2).task_id, D(param=3).task_id), }, "event.core.dependency.present": { (D(param=1).task_id,), (D(param=2).task_id,), (D(param=3).task_id,), }, }, ) self.assertEqual( eval_contents(A().output()), [ A(param=1).task_id, [B(param=1).task_id, [C(param=1).task_id, [D(param=1).task_id], [D(param=2).task_id]]], [B(param=2).task_id, [C(param=2).task_id, [D(param=2).task_id], [D(param=3).task_id]]], ], ) ================================================ FILE: test/execution_summary_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015-2015 Spotify AB # # 
Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import datetime import threading from enum import Enum import mock from helpers import LuigiTestCase, RunOnceTask, with_config import luigi import luigi.execution_summary import luigi.worker class ExecutionSummaryTest(LuigiTestCase): def setUp(self): super(ExecutionSummaryTest, self).setUp() self.scheduler = luigi.scheduler.Scheduler(prune_on_get_work=False) self.worker = luigi.worker.Worker(scheduler=self.scheduler) def run_task(self, task): self.worker.add(task) # schedule self.worker.run() # run def summary_dict(self): return luigi.execution_summary._summary_dict(self.worker) def summary(self): return luigi.execution_summary.summary(self.worker) def test_all_statuses(self): class Bar(luigi.Task): num = luigi.IntParameter() def run(self): if self.num == 0: raise ValueError() def complete(self): if self.num == 1: return True return False class Foo(luigi.Task): def requires(self): for i in range(5): yield Bar(i) self.run_task(Foo()) d = self.summary_dict() self.assertEqual({Bar(num=1)}, d["already_done"]) self.assertEqual({Bar(num=2), Bar(num=3), Bar(num=4)}, d["completed"]) self.assertEqual({Bar(num=0)}, d["failed"]) self.assertEqual({Foo()}, d["upstream_failure"]) self.assertFalse(d["upstream_missing_dependency"]) self.assertFalse(d["run_by_other_worker"]) self.assertFalse(d["still_pending_ext"]) summary = self.summary() expected = [ "", "===== Luigi Execution Summary =====", "", "Scheduled 6 tasks of which:", "* 1 complete ones were 
encountered:", " - 1 Bar(num=1)", "* 3 ran successfully:", " - 3 Bar(num=2,3,4)", "* 1 failed:", " - 1 Bar(num=0)", "* 1 were left pending, among these:", " * 1 had failed dependencies:", " - 1 Foo()", "", "This progress looks :( because there were failed tasks", "", "===== Luigi Execution Summary =====", "", ] result = summary.split("\n") self.assertEqual(len(result), len(expected)) for i, line in enumerate(result): self.assertEqual(line, expected[i]) def test_batch_complete(self): ran_tasks = set() class MaxBatchTask(luigi.Task): param = luigi.IntParameter(batch_method=max) def run(self): ran_tasks.add(self.param) def complete(self): return any(self.param <= ran_param for ran_param in ran_tasks) class MaxBatches(luigi.WrapperTask): def requires(self): return map(MaxBatchTask, range(5)) self.run_task(MaxBatches()) d = self.summary_dict() expected_completed = { MaxBatchTask(0), MaxBatchTask(1), MaxBatchTask(2), MaxBatchTask(3), MaxBatchTask(4), MaxBatches(), } self.assertEqual(expected_completed, d["completed"]) def test_batch_fail(self): class MaxBatchFailTask(luigi.Task): param = luigi.IntParameter(batch_method=max) def run(self): assert self.param < 4 def complete(self): return False class MaxBatches(luigi.WrapperTask): def requires(self): return map(MaxBatchFailTask, range(5)) self.run_task(MaxBatches()) d = self.summary_dict() expected_failed = { MaxBatchFailTask(0), MaxBatchFailTask(1), MaxBatchFailTask(2), MaxBatchFailTask(3), MaxBatchFailTask(4), } self.assertEqual(expected_failed, d["failed"]) def test_check_complete_error(self): class Bar(luigi.Task): def run(self): pass def complete(self): raise Exception return True class Foo(luigi.Task): def requires(self): yield Bar() self.run_task(Foo()) d = self.summary_dict() self.assertEqual({Foo()}, d["still_pending_not_ext"]) self.assertEqual({Foo()}, d["upstream_scheduling_error"]) self.assertEqual({Bar()}, d["scheduling_error"]) self.assertFalse(d["not_run"]) self.assertFalse(d["already_done"]) 
self.assertFalse(d["completed"]) self.assertFalse(d["failed"]) self.assertFalse(d["upstream_failure"]) self.assertFalse(d["upstream_missing_dependency"]) self.assertFalse(d["run_by_other_worker"]) self.assertFalse(d["still_pending_ext"]) summary = self.summary() expected = [ "", "===== Luigi Execution Summary =====", "", "Scheduled 2 tasks of which:", "* 1 failed scheduling:", " - 1 Bar()", "* 1 were left pending, among these:", " * 1 had dependencies whose scheduling failed:", " - 1 Foo()", "", "Did not run any tasks", "This progress looks :( because there were tasks whose scheduling failed", "", "===== Luigi Execution Summary =====", "", ] result = summary.split("\n") self.assertEqual(len(result), len(expected)) for i, line in enumerate(result): self.assertEqual(line, expected[i]) def test_not_run_error(self): class Bar(luigi.Task): def complete(self): return True class Foo(luigi.Task): def requires(self): yield Bar() def new_func(*args, **kwargs): return None with mock.patch("luigi.scheduler.Scheduler.add_task", new_func): self.run_task(Foo()) d = self.summary_dict() self.assertEqual({Foo()}, d["still_pending_not_ext"]) self.assertEqual({Foo()}, d["not_run"]) self.assertEqual({Bar()}, d["already_done"]) self.assertFalse(d["upstream_scheduling_error"]) self.assertFalse(d["scheduling_error"]) self.assertFalse(d["completed"]) self.assertFalse(d["failed"]) self.assertFalse(d["upstream_failure"]) self.assertFalse(d["upstream_missing_dependency"]) self.assertFalse(d["run_by_other_worker"]) self.assertFalse(d["still_pending_ext"]) summary = self.summary() expected = [ "", "===== Luigi Execution Summary =====", "", "Scheduled 2 tasks of which:", "* 1 complete ones were encountered:", " - 1 Bar()", "* 1 were left pending, among these:", " * 1 was not granted run permission by the scheduler:", " - 1 Foo()", "", "Did not run any tasks", "This progress looks :| because there were tasks that were not granted run permission by the scheduler", "", "===== Luigi Execution 
Summary =====", "", ] result = summary.split("\n") self.assertEqual(len(result), len(expected)) for i, line in enumerate(result): self.assertEqual(line, expected[i]) def test_deps_error(self): class Bar(luigi.Task): def run(self): pass def complete(self): return True class Foo(luigi.Task): def requires(self): raise Exception yield Bar() self.run_task(Foo()) d = self.summary_dict() self.assertEqual({Foo()}, d["scheduling_error"]) self.assertFalse(d["upstream_scheduling_error"]) self.assertFalse(d["not_run"]) self.assertFalse(d["already_done"]) self.assertFalse(d["completed"]) self.assertFalse(d["failed"]) self.assertFalse(d["upstream_failure"]) self.assertFalse(d["upstream_missing_dependency"]) self.assertFalse(d["run_by_other_worker"]) self.assertFalse(d["still_pending_ext"]) summary = self.summary() expected = [ "", "===== Luigi Execution Summary =====", "", "Scheduled 1 tasks of which:", "* 1 failed scheduling:", " - 1 Foo()", "", "Did not run any tasks", "This progress looks :( because there were tasks whose scheduling failed", "", "===== Luigi Execution Summary =====", "", ] result = summary.split("\n") self.assertEqual(len(result), len(expected)) for i, line in enumerate(result): self.assertEqual(line, expected[i]) @with_config({"execution_summary": {"summary_length": "1"}}) def test_config_summary_limit(self): class Bar(luigi.Task): num = luigi.IntParameter() def run(self): pass def complete(self): return True class Biz(Bar): pass class Bat(Bar): pass class Wut(Bar): pass class Foo(luigi.Task): def requires(self): yield Bat(1) yield Wut(1) yield Biz(1) for i in range(4): yield Bar(i) def complete(self): return False self.run_task(Foo()) d = self.summary_dict() self.assertEqual({Bat(1), Wut(1), Biz(1), Bar(0), Bar(1), Bar(2), Bar(3)}, d["already_done"]) self.assertEqual({Foo()}, d["completed"]) self.assertFalse(d["failed"]) self.assertFalse(d["upstream_failure"]) self.assertFalse(d["upstream_missing_dependency"]) self.assertFalse(d["run_by_other_worker"]) 
self.assertFalse(d["still_pending_ext"]) summary = self.summary() expected = [ "", "===== Luigi Execution Summary =====", "", "Scheduled 8 tasks of which:", "* 7 complete ones were encountered:", " - 4 Bar(num=0...3)", " ...", "* 1 ran successfully:", " - 1 Foo()", "", "This progress looks :) because there were no failed tasks or missing dependencies", "", "===== Luigi Execution Summary =====", "", ] result = summary.split("\n") self.assertEqual(len(result), len(expected)) for i, line in enumerate(result): self.assertEqual(line, expected[i]) def test_upstream_not_running(self): class ExternalBar(luigi.ExternalTask): num = luigi.IntParameter() def complete(self): if self.num == 1: return True return False class Bar(luigi.Task): num = luigi.IntParameter() def run(self): if self.num == 0: raise ValueError() class Foo(luigi.Task): def requires(self): for i in range(5): yield ExternalBar(i) yield Bar(i) self.run_task(Foo()) d = self.summary_dict() self.assertEqual({ExternalBar(num=1)}, d["already_done"]) self.assertEqual({Bar(num=1), Bar(num=2), Bar(num=3), Bar(num=4)}, d["completed"]) self.assertEqual({Bar(num=0)}, d["failed"]) self.assertEqual({Foo()}, d["upstream_failure"]) self.assertEqual({Foo()}, d["upstream_missing_dependency"]) self.assertFalse(d["run_by_other_worker"]) self.assertEqual({ExternalBar(num=0), ExternalBar(num=2), ExternalBar(num=3), ExternalBar(num=4)}, d["still_pending_ext"]) s = self.summary() self.assertIn("\n* 1 complete ones were encountered:\n - 1 ExternalBar(num=1)\n", s) self.assertIn("\n* 4 ran successfully:\n - 4 Bar(num=1...4)\n", s) self.assertIn("\n* 1 failed:\n - 1 Bar(num=0)\n", s) self.assertIn("\n* 5 were left pending, among these:\n * 4 were missing external dependencies:\n - 4 ExternalBar(num=", s) self.assertIn( "\n * 1 had failed dependencies:\n" " - 1 Foo()\n" " * 1 had missing dependencies:\n" " - 1 Foo()\n\n" "This progress looks :( because there were failed tasks\n", s, ) self.assertNotIn("\n\n\n", s) def 
test_already_running(self): lock1 = threading.Lock() lock2 = threading.Lock() class ParentTask(RunOnceTask): def requires(self): yield LockTask() class LockTask(RunOnceTask): def run(self): lock2.release() lock1.acquire() self.comp = True lock1.acquire() lock2.acquire() other_worker = luigi.worker.Worker(scheduler=self.scheduler, worker_id="other_worker") other_worker.add(ParentTask()) t1 = threading.Thread(target=other_worker.run) t1.start() lock2.acquire() self.run_task(ParentTask()) lock1.release() t1.join() d = self.summary_dict() self.assertEqual({LockTask()}, d["run_by_other_worker"]) self.assertEqual({ParentTask()}, d["upstream_run_by_other_worker"]) s = self.summary() self.assertIn( "\nScheduled 2 tasks of which:\n" "* 2 were left pending, among these:\n" " * 1 were being run by another worker:\n" " - 1 LockTask()\n" " * 1 had dependencies that were being run by other worker:\n" " - 1 ParentTask()\n", s, ) self.assertIn( "\n\nThe other workers were:\n" " - other_worker ran 1 tasks\n\n" "Did not run any tasks\n" "This progress looks :) because there were no failed " "tasks or missing dependencies\n", s, ) self.assertNotIn("\n\n\n", s) def test_already_running_2(self): class AlreadyRunningTask(luigi.Task): def run(self): pass other_worker = luigi.worker.Worker(scheduler=self.scheduler, worker_id="other_worker") other_worker.add(AlreadyRunningTask()) # This also registers this worker old_func = luigi.scheduler.Scheduler.get_work def new_func(*args, **kwargs): new_kwargs = kwargs.copy() new_kwargs["worker"] = "other_worker" old_func(*args, **new_kwargs) return old_func(*args, **kwargs) with mock.patch("luigi.scheduler.Scheduler.get_work", new_func): self.run_task(AlreadyRunningTask()) d = self.summary_dict() self.assertFalse(d["already_done"]) self.assertFalse(d["completed"]) self.assertFalse(d["not_run"]) self.assertEqual({AlreadyRunningTask()}, d["run_by_other_worker"]) def test_not_run(self): class AlreadyRunningTask(luigi.Task): def run(self): pass 
other_worker = luigi.worker.Worker(scheduler=self.scheduler, worker_id="other_worker") other_worker.add(AlreadyRunningTask()) # This also registers this worker old_func = luigi.scheduler.Scheduler.get_work def new_func(*args, **kwargs): kwargs["current_tasks"] = None old_func(*args, **kwargs) return old_func(*args, **kwargs) with mock.patch("luigi.scheduler.Scheduler.get_work", new_func): self.run_task(AlreadyRunningTask()) d = self.summary_dict() self.assertFalse(d["already_done"]) self.assertFalse(d["completed"]) self.assertFalse(d["run_by_other_worker"]) self.assertEqual({AlreadyRunningTask()}, d["not_run"]) s = self.summary() self.assertIn( "\nScheduled 1 tasks of which:\n" "* 1 were left pending, among these:\n" " * 1 was not granted run permission by the scheduler:\n" " - 1 AlreadyRunningTask()\n", s, ) self.assertNotIn("\n\n\n", s) def test_somebody_else_finish_task(self): class SomeTask(RunOnceTask): pass other_worker = luigi.worker.Worker(scheduler=self.scheduler, worker_id="other_worker") self.worker.add(SomeTask()) other_worker.add(SomeTask()) other_worker.run() self.worker.run() d = self.summary_dict() self.assertFalse(d["already_done"]) self.assertFalse(d["completed"]) self.assertFalse(d["run_by_other_worker"]) self.assertEqual({SomeTask()}, d["not_run"]) def test_somebody_else_disables_task(self): class SomeTask(luigi.Task): def complete(self): return False def run(self): raise ValueError() other_worker = luigi.worker.Worker(scheduler=self.scheduler, worker_id="other_worker") self.worker.add(SomeTask()) other_worker.add(SomeTask()) other_worker.run() # Assuming it is disabled for a while after this self.worker.run() d = self.summary_dict() self.assertFalse(d["already_done"]) self.assertFalse(d["completed"]) self.assertFalse(d["run_by_other_worker"]) self.assertEqual({SomeTask()}, d["not_run"]) def test_larger_tree(self): class Dog(RunOnceTask): def requires(self): yield Cat(2) class Cat(luigi.Task): num = luigi.IntParameter() def __init__(self, *args, 
**kwargs): super(Cat, self).__init__(*args, **kwargs) self.comp = False def run(self): if self.num == 2: raise ValueError() self.comp = True def complete(self): if self.num == 1: return True else: return self.comp class Bar(RunOnceTask): num = luigi.IntParameter() def requires(self): if self.num == 0: yield ExternalBar() yield Cat(0) if self.num == 1: yield Cat(0) yield Cat(1) if self.num == 2: yield Dog() class Foo(luigi.Task): def requires(self): for i in range(3): yield Bar(i) class ExternalBar(luigi.ExternalTask): def complete(self): return False self.run_task(Foo()) d = self.summary_dict() self.assertEqual({Cat(num=1)}, d["already_done"]) self.assertEqual({Cat(num=0), Bar(num=1)}, d["completed"]) self.assertEqual({Cat(num=2)}, d["failed"]) self.assertEqual({Dog(), Bar(num=2), Foo()}, d["upstream_failure"]) self.assertEqual({Bar(num=0), Foo()}, d["upstream_missing_dependency"]) self.assertFalse(d["run_by_other_worker"]) self.assertEqual({ExternalBar()}, d["still_pending_ext"]) s = self.summary() self.assertNotIn("\n\n\n", s) def test_with_dates(self): """Just test that it doesn't crash with date params""" start = datetime.date(1998, 3, 23) class Bar(RunOnceTask): date = luigi.DateParameter() class Foo(luigi.Task): def requires(self): for i in range(10): new_date = start + datetime.timedelta(days=i) yield Bar(date=new_date) self.run_task(Foo()) d = self.summary_dict() exp_set = {Bar(start + datetime.timedelta(days=i)) for i in range(10)} exp_set.add(Foo()) self.assertEqual(exp_set, d["completed"]) s = self.summary() self.assertIn("date=1998-0", s) self.assertIn("Scheduled 11 tasks", s) self.assertIn("Luigi Execution Summary", s) self.assertNotIn("00:00:00", s) self.assertNotIn("\n\n\n", s) def test_with_ranges_minutes(self): start = datetime.datetime(1998, 3, 23, 1, 50) class Bar(RunOnceTask): time = luigi.DateMinuteParameter() class Foo(luigi.Task): def requires(self): for i in range(300): new_time = start + datetime.timedelta(minutes=i) yield 
Bar(time=new_time) self.run_task(Foo()) d = self.summary_dict() exp_set = {Bar(start + datetime.timedelta(minutes=i)) for i in range(300)} exp_set.add(Foo()) self.assertEqual(exp_set, d["completed"]) s = self.summary() self.assertIn("Bar(time=1998-03-23T0150...1998-03-23T0649)", s) self.assertNotIn("\n\n\n", s) def test_with_ranges_one_param(self): class Bar(RunOnceTask): num = luigi.IntParameter() class Foo(luigi.Task): def requires(self): for i in range(11): yield Bar(i) self.run_task(Foo()) d = self.summary_dict() exp_set = {Bar(i) for i in range(11)} exp_set.add(Foo()) self.assertEqual(exp_set, d["completed"]) s = self.summary() self.assertIn("Bar(num=0...10)", s) self.assertNotIn("\n\n\n", s) def test_with_ranges_multiple_params(self): class Bar(RunOnceTask): num1 = luigi.IntParameter() num2 = luigi.IntParameter() num3 = luigi.IntParameter() class Foo(luigi.Task): def requires(self): for i in range(5): yield Bar(5, i, 25) self.run_task(Foo()) d = self.summary_dict() exp_set = {Bar(5, i, 25) for i in range(5)} exp_set.add(Foo()) self.assertEqual(exp_set, d["completed"]) s = self.summary() self.assertIn("- 5 Bar(num1=5, num2=0...4, num3=25)", s) self.assertNotIn("\n\n\n", s) def test_with_two_tasks(self): class Bar(RunOnceTask): num = luigi.IntParameter() num2 = luigi.IntParameter() class Foo(luigi.Task): def requires(self): for i in range(2): yield Bar(i, 2 * i) self.run_task(Foo()) d = self.summary_dict() self.assertEqual({Foo(), Bar(num=0, num2=0), Bar(num=1, num2=2)}, d["completed"]) summary = self.summary() result = summary.split("\n") expected = [ "", "===== Luigi Execution Summary =====", "", "Scheduled 3 tasks of which:", "* 3 ran successfully:", " - 2 Bar(num=0, num2=0) and Bar(num=1, num2=2)", " - 1 Foo()", "", "This progress looks :) because there were no failed tasks or missing dependencies", "", "===== Luigi Execution Summary =====", "", ] self.assertEqual(len(result), len(expected)) for i, line in enumerate(result): self.assertEqual(line, 
expected[i]) def test_really_long_param_name(self): class Bar(RunOnceTask): This_is_a_really_long_parameter_that_we_should_not_print_out_because_people_will_get_annoyed = luigi.IntParameter() class Foo(luigi.Task): def requires(self): yield Bar(0) self.run_task(Foo()) s = self.summary() self.assertIn("Bar(...)", s) self.assertNotIn("Did not run any tasks", s) self.assertNotIn("\n\n\n", s) def test_multiple_params_multiple_same_task_family(self): class Bar(RunOnceTask): num = luigi.IntParameter() num2 = luigi.IntParameter() class Foo(luigi.Task): def requires(self): for i in range(4): yield Bar(i, 2 * i) self.run_task(Foo()) summary = self.summary() result = summary.split("\n") expected = [ "", "===== Luigi Execution Summary =====", "", "Scheduled 5 tasks of which:", "* 5 ran successfully:", " - 4 Bar(num=0, num2=0) ...", " - 1 Foo()", "", "This progress looks :) because there were no failed tasks or missing dependencies", "", "===== Luigi Execution Summary =====", "", ] self.assertEqual(len(result), len(expected)) for i, line in enumerate(result): self.assertEqual(line, expected[i]) def test_happy_smiley_face_normal(self): class Bar(RunOnceTask): num = luigi.IntParameter() num2 = luigi.IntParameter() class Foo(luigi.Task): def requires(self): for i in range(4): yield Bar(i, 2 * i) self.run_task(Foo()) s = self.summary() self.assertIn("\nThis progress looks :) because there were no failed tasks or missing dependencies", s) self.assertNotIn("Did not run any tasks", s) self.assertNotIn("\n\n\n", s) def test_happy_smiley_face_other_workers(self): lock1 = threading.Lock() lock2 = threading.Lock() class ParentTask(RunOnceTask): def requires(self): yield LockTask() class LockTask(RunOnceTask): def run(self): lock2.release() lock1.acquire() self.comp = True lock1.acquire() lock2.acquire() other_worker = luigi.worker.Worker(scheduler=self.scheduler, worker_id="other_worker") other_worker.add(ParentTask()) t1 = threading.Thread(target=other_worker.run) t1.start() 
lock2.acquire() self.run_task(ParentTask()) lock1.release() t1.join() s = self.summary() self.assertIn("\nThis progress looks :) because there were no failed tasks or missing dependencies", s) self.assertNotIn("\n\n\n", s) def test_sad_smiley_face(self): class ExternalBar(luigi.ExternalTask): def complete(self): return False class Bar(luigi.Task): num = luigi.IntParameter() def run(self): if self.num == 0: raise ValueError() class Foo(luigi.Task): def requires(self): for i in range(5): yield Bar(i) yield ExternalBar() self.run_task(Foo()) s = self.summary() self.assertIn("\nThis progress looks :( because there were failed tasks", s) self.assertNotIn("Did not run any tasks", s) self.assertNotIn("\n\n\n", s) def test_neutral_smiley_face(self): class ExternalBar(luigi.ExternalTask): def complete(self): return False class Foo(luigi.Task): def requires(self): yield ExternalBar() self.run_task(Foo()) s = self.summary() self.assertIn("\nThis progress looks :| because there were missing external dependencies", s) self.assertNotIn("\n\n\n", s) def test_did_not_run_any_tasks(self): class ExternalBar(luigi.ExternalTask): num = luigi.IntParameter() def complete(self): if self.num == 5: return True return False class Foo(luigi.Task): def requires(self): for i in range(10): yield ExternalBar(i) self.run_task(Foo()) d = self.summary_dict() self.assertEqual({ExternalBar(5)}, d["already_done"]) self.assertEqual({ExternalBar(i) for i in range(10) if i != 5}, d["still_pending_ext"]) self.assertEqual({Foo()}, d["upstream_missing_dependency"]) s = self.summary() self.assertIn("\n\nDid not run any tasks\nThis progress looks :| because there were missing external dependencies", s) self.assertNotIn("\n\n\n", s) def test_example(self): class MyExternal(luigi.ExternalTask): def complete(self): return False class Boom(luigi.Task): this_is_a_really_long_I_mean_way_too_long_and_annoying_parameter = luigi.IntParameter() def requires(self): for i in range(5, 200): yield Bar(i) class 
Foo(luigi.Task): num = luigi.IntParameter() num2 = luigi.IntParameter() def requires(self): yield MyExternal() yield Boom(0) class Bar(luigi.Task): num = luigi.IntParameter() def complete(self): return True class DateTask(luigi.Task): date = luigi.DateParameter() num = luigi.IntParameter() def requires(self): yield MyExternal() yield Boom(0) class EntryPoint(luigi.Task): def requires(self): for i in range(10): yield Foo(100, 2 * i) for i in range(10): yield DateTask(datetime.date(1998, 3, 23) + datetime.timedelta(days=i), 5) self.run_task(EntryPoint()) summary = self.summary() expected = [ "", "===== Luigi Execution Summary =====", "", "Scheduled 218 tasks of which:", "* 195 complete ones were encountered:", " - 195 Bar(num=5...199)", "* 1 ran successfully:", " - 1 Boom(...)", "* 22 were left pending, among these:", " * 1 were missing external dependencies:", " - 1 MyExternal()", " * 21 had missing dependencies:", " - 10 DateTask(date=1998-03-23...1998-04-01, num=5)", " - 1 EntryPoint()", " - 10 Foo(num=100, num2=0) ...", "", "This progress looks :| because there were missing external dependencies", "", "===== Luigi Execution Summary =====", "", ] result = summary.split("\n") self.assertEqual(len(result), len(expected)) for i, line in enumerate(result): self.assertEqual(line, expected[i]) def test_with_datehours(self): """Just test that it doesn't crash with datehour params""" start = datetime.datetime(1998, 3, 23, 5) class Bar(RunOnceTask): datehour = luigi.DateHourParameter() class Foo(luigi.Task): def requires(self): for i in range(10): new_date = start + datetime.timedelta(hours=i) yield Bar(datehour=new_date) self.run_task(Foo()) d = self.summary_dict() exp_set = {Bar(start + datetime.timedelta(hours=i)) for i in range(10)} exp_set.add(Foo()) self.assertEqual(exp_set, d["completed"]) s = self.summary() self.assertIn("datehour=1998-03-23T0", s) self.assertIn("Scheduled 11 tasks", s) self.assertIn("Luigi Execution Summary", s) self.assertNotIn("00:00:00", s) 
self.assertNotIn("\n\n\n", s) def test_with_months(self): """Just test that it doesn't crash with month params""" start = datetime.datetime(1998, 3, 23) class Bar(RunOnceTask): month = luigi.MonthParameter() class Foo(luigi.Task): def requires(self): for i in range(3): new_date = start + datetime.timedelta(days=30 * i) yield Bar(month=new_date) self.run_task(Foo()) d = self.summary_dict() exp_set = {Bar(start + datetime.timedelta(days=30 * i)) for i in range(3)} exp_set.add(Foo()) self.assertEqual(exp_set, d["completed"]) s = self.summary() self.assertIn("month=1998-0", s) self.assertIn("Scheduled 4 tasks", s) self.assertIn("Luigi Execution Summary", s) self.assertNotIn("00:00:00", s) self.assertNotIn("\n\n\n", s) def test_multiple_dash_dash_workers(self): """ Don't print own worker with ``--workers 2`` setting. """ self.worker = luigi.worker.Worker(scheduler=self.scheduler, worker_processes=2) class Foo(RunOnceTask): pass self.run_task(Foo()) d = self.summary_dict() self.assertEqual(set(), d["run_by_other_worker"]) s = self.summary() self.assertNotIn("The other workers were", s) self.assertIn("This progress looks :) because there were no failed ", s) self.assertNotIn("\n\n\n", s) def test_with_uncomparable_parameters(self): """ Don't rely on parameters being sortable """ class Color(Enum): red = 1 yellow = 2 class Bar(RunOnceTask): eparam = luigi.EnumParameter(enum=Color) class Baz(RunOnceTask): eparam = luigi.EnumParameter(enum=Color) another_param = luigi.IntParameter() class Foo(luigi.Task): def requires(self): yield Bar(Color.red) yield Bar(Color.yellow) yield Baz(Color.red, 5) yield Baz(Color.yellow, 5) self.run_task(Foo()) s = self.summary() self.assertIn("yellow", s) def test_with_dict_dependency(self): """Just test that it doesn't crash with dict params in dependencies""" args = dict(start=datetime.date(1998, 3, 23), num=3) class Bar(RunOnceTask): args = luigi.DictParameter() class Foo(luigi.Task): def requires(self): for i in range(10): new_dict = 
args.copy() new_dict["start"] = str(new_dict["start"] + datetime.timedelta(days=i)) yield Bar(args=new_dict) self.run_task(Foo()) d = self.summary_dict() exp_set = set() for i in range(10): new_dict = args.copy() new_dict["start"] = str(new_dict["start"] + datetime.timedelta(days=i)) exp_set.add(Bar(new_dict)) exp_set.add(Foo()) self.assertEqual(exp_set, d["completed"]) s = self.summary() self.assertIn('"num": 3', s) self.assertIn('"start": "1998-0', s) self.assertIn("Scheduled 11 tasks", s) self.assertIn("Luigi Execution Summary", s) self.assertNotIn("00:00:00", s) self.assertNotIn("\n\n\n", s) def test_with_dict_argument(self): """Just test that it doesn't crash with dict params""" args = dict(start=str(datetime.date(1998, 3, 23)), num=3) class Bar(RunOnceTask): args = luigi.DictParameter() self.run_task(Bar(args=args)) d = self.summary_dict() exp_set = set() exp_set.add(Bar(args=args)) self.assertEqual(exp_set, d["completed"]) s = self.summary() self.assertIn('"num": 3', s) self.assertIn('"start": "1998-0', s) self.assertIn("Scheduled 1 task", s) self.assertIn("Luigi Execution Summary", s) self.assertNotIn("00:00:00", s) self.assertNotIn("\n\n\n", s) """ Test that a task once crashing and then succeeding should be counted as no failure. 
""" def test_status_with_task_retry(self): class Foo(luigi.Task): run_count = 0 def run(self): self.run_count += 1 if self.run_count == 1: raise ValueError() def complete(self): return self.run_count > 0 self.run_task(Foo()) self.run_task(Foo()) d = self.summary_dict() self.assertEqual({Foo()}, d["completed"]) self.assertEqual({Foo()}, d["ever_failed"]) self.assertFalse(d["failed"]) self.assertFalse(d["upstream_failure"]) self.assertFalse(d["upstream_missing_dependency"]) self.assertFalse(d["run_by_other_worker"]) self.assertFalse(d["still_pending_ext"]) s = self.summary() self.assertIn("Scheduled 1 task", s) self.assertIn("Luigi Execution Summary", s) self.assertNotIn("ever failed", s) self.assertIn("\n\nThis progress looks :) because there were failed tasks but they all succeeded in a retry", s) ================================================ FILE: test/factorial_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from helpers import unittest import luigi class Factorial(luigi.Task): """This calculates factorials *online* and does not write its results anywhere Demonstrates the ability for dependencies between Tasks and not just between their output. 
""" n = luigi.IntParameter(default=100) def requires(self): if self.n > 1: return Factorial(self.n - 1) def run(self): if self.n > 1: self.value = self.n * self.requires().value else: self.value = 1 self.complete = lambda: True def complete(self): return False class FactorialTest(unittest.TestCase): def test_invoke(self): luigi.build([Factorial(100)], local_scheduler=True) self.assertEqual(Factorial(42).value, 1405006117752879898543142606244511569936384000000000) ================================================ FILE: test/fib_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from helpers import unittest import luigi import luigi.interface from luigi.mock import MockTarget # Calculates Fibonacci numbers :) class Fib(luigi.Task): n = luigi.IntParameter(default=100) def requires(self): if self.n >= 2: return [Fib(self.n - 1), Fib(self.n - 2)] else: return [] def output(self): return MockTarget("/tmp/fib_%d" % self.n) def run(self): if self.n == 0: s = 0 elif self.n == 1: s = 1 else: s = 0 for input in self.input(): for line in input.open("r"): s += int(line.strip()) f = self.output().open("w") f.write("%d\n" % s) f.close() class FibTestBase(unittest.TestCase): def setUp(self): MockTarget.fs.clear() class FibTest(FibTestBase): def test_invoke(self): luigi.build([Fib(100)], local_scheduler=True) self.assertEqual(MockTarget.fs.get_data("/tmp/fib_10"), b"55\n") self.assertEqual(MockTarget.fs.get_data("/tmp/fib_100"), b"354224848179261915075\n") def test_cmdline(self): luigi.run(["--local-scheduler", "--no-lock", "Fib", "--n", "100"]) self.assertEqual(MockTarget.fs.get_data("/tmp/fib_10"), b"55\n") self.assertEqual(MockTarget.fs.get_data("/tmp/fib_100"), b"354224848179261915075\n") def test_build_internal(self): luigi.build([Fib(100)], local_scheduler=True) self.assertEqual(MockTarget.fs.get_data("/tmp/fib_10"), b"55\n") self.assertEqual(MockTarget.fs.get_data("/tmp/fib_100"), b"354224848179261915075\n") ================================================ FILE: test/hdfs_client_test.py ================================================ import itertools import threading import unittest from luigi.contrib.hdfs import get_autoconfig_client class HdfsClientTest(unittest.TestCase): def test_get_autoconfig_client_cached(self): original_client = get_autoconfig_client() for _ in range(100): self.assertIs(original_client, get_autoconfig_client()) def test_threaded_clients_different(self): clients = [] def add_client(): clients.append(get_autoconfig_client()) # run a bunch of threads to get new clients in them threads = [threading.Thread(target=add_client) 
for _ in range(10)] for thread in threads: thread.start() for thread in threads: thread.join() for client1, client2 in itertools.combinations(clients, 2): self.assertIsNot(client1, client2) ================================================ FILE: test/helpers.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import functools import itertools import os import re import tempfile import unittest from contextlib import contextmanager import luigi import luigi.cmdline_parser import luigi.task_register from luigi.cmdline_parser import CmdlineParser def skipOnTravisAndGithubActions(reason): if _override_skip_CI_tests(): # Do not skip the CI tests return unittest.skipIf(False, "") # run the skip CI tests logic return unittest.skipIf(_running_on_travis() or _running_on_github_actions(), reason) def skipOnGithubActions(reason): return unittest.skipIf(_running_on_github_actions(), reason) def _running_on_travis(): return os.getenv("TRAVIS") == "true" def _running_on_github_actions(): return os.getenv("GITHUB_ACTIONS") == "true" def _override_skip_CI_tests(): return os.getenv("OVERRIDE_SKIP_CI_TESTS") == "true" class with_config: """ Decorator to override config settings for the length of a function. Usage: .. code-block: python >>> import luigi.configuration >>> @with_config({'foo': {'bar': 'baz'}}) ... def my_test(): ... print(luigi.configuration.get_config().get("foo", "bar")) ... 
>>> my_test() baz >>> @with_config({'hoo': {'bar': 'buz'}}) ... @with_config({'foo': {'bar': 'baz'}}) ... def my_test(): ... print(luigi.configuration.get_config().get("foo", "bar")) ... print(luigi.configuration.get_config().get("hoo", "bar")) ... >>> my_test() baz buz >>> @with_config({'foo': {'bar': 'buz'}}) ... @with_config({'foo': {'bar': 'baz'}}) ... def my_test(): ... print(luigi.configuration.get_config().get("foo", "bar")) ... >>> my_test() baz >>> @with_config({'foo': {'bur': 'buz'}}) ... @with_config({'foo': {'bar': 'baz'}}) ... def my_test(): ... print(luigi.configuration.get_config().get("foo", "bar")) ... print(luigi.configuration.get_config().get("foo", "bur")) ... >>> my_test() baz buz >>> @with_config({'foo': {'bur': 'buz'}}) ... @with_config({'foo': {'bar': 'baz'}}, replace_sections=True) ... def my_test(): ... print(luigi.configuration.get_config().get("foo", "bar")) ... print(luigi.configuration.get_config().get("foo", "bur", "no_bur")) ... >>> my_test() baz no_bur """ def __init__(self, config, replace_sections=False): self.config = config self.replace_sections = replace_sections def _make_dict(self, old_dict): if self.replace_sections: old_dict.update(self.config) return old_dict def get_section(sec): old_sec = old_dict.get(sec, {}) new_sec = self.config.get(sec, {}) old_sec.update(new_sec) return old_sec all_sections = itertools.chain(old_dict.keys(), self.config.keys()) return {sec: get_section(sec) for sec in all_sections} def __call__(self, fun): @functools.wraps(fun) def wrapper(*args, **kwargs): import luigi.configuration orig_conf = luigi.configuration.LuigiConfigParser.instance() new_conf = luigi.configuration.LuigiConfigParser() luigi.configuration.LuigiConfigParser._instance = new_conf orig_dict = {k: dict(orig_conf.items(k)) for k in orig_conf.sections()} new_dict = self._make_dict(orig_dict) for section, settings in new_dict.items(): new_conf.add_section(section) for name, value in settings.items(): new_conf.set(section, name, 
value) try: return fun(*args, **kwargs) finally: luigi.configuration.LuigiConfigParser._instance = orig_conf return wrapper class RunOnceTask(luigi.Task): def __init__(self, *args, **kwargs): super(RunOnceTask, self).__init__(*args, **kwargs) self.comp = False def complete(self): return self.comp def run(self): self.comp = True # string subclass that matches arguments containing the specified substring # for use in mock 'called_with' assertions class StringContaining(str): def __eq__(self, other_str): return self in other_str class LuigiTestCase(unittest.TestCase): """ Tasks registred within a test case will get unregistered in a finalizer Instance caches are cleared before and after all runs """ def setUp(self): super(LuigiTestCase, self).setUp() self._stashed_reg = luigi.task_register.Register._get_reg() luigi.task_register.Register.clear_instance_cache() def tearDown(self): luigi.task_register.Register._set_reg(self._stashed_reg) super(LuigiTestCase, self).tearDown() luigi.task_register.Register.clear_instance_cache() def run_locally(self, args): """Helper for running tests testing more of the stack, the command line parsing and task from name intstantiation parts in particular.""" temp = CmdlineParser._instance try: CmdlineParser._instance = None run_exit_status = luigi.run(["--local-scheduler", "--no-lock"] + args) finally: CmdlineParser._instance = temp return run_exit_status def run_locally_split(self, space_seperated_args): """Helper for running tests testing more of the stack, the command line parsing and task from name intstantiation parts in particular.""" return self.run_locally(space_seperated_args.split(" ")) class parsing: """ Convenient decorator for test cases to set the parsing environment. 
""" def __init__(self, cmds): self.cmds = cmds def __call__(self, fun): @functools.wraps(fun) def wrapper(*args, **kwargs): with CmdlineParser.global_instance(self.cmds, allow_override=True): return fun(*args, **kwargs) return wrapper def in_parse(cmds, deferred_computation): with CmdlineParser.global_instance(cmds) as cp: deferred_computation(cp.get_task_obj()) @contextmanager def temporary_unloaded_module(python_file_contents): """Create an importable module Return the name of importable module name given its file contents (source code)""" with tempfile.NamedTemporaryFile(dir="test/", prefix="_test_time_generated_module", suffix=".py") as temp_module_file: temp_module_file.file.write(python_file_contents) temp_module_file.file.flush() temp_module_path = temp_module_file.name temp_module_name = re.search(r"/(_test_time_generated_module.*).py", temp_module_path).group(1) yield temp_module_name ================================================ FILE: test/helpers_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2016 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from helpers import LuigiTestCase, RunOnceTask import luigi import luigi.date_interval import luigi.interface import luigi.notifications class LuigiTestCaseTest(LuigiTestCase): def test_1(self): class MyClass(luigi.Task): pass self.assertTrue(self.run_locally(["MyClass"])) def test_2(self): class MyClass(luigi.Task): pass self.assertTrue(self.run_locally(["MyClass"])) class RunOnceTaskTest(LuigiTestCase): def test_complete_behavior(self): """ Verify that RunOnceTask works as expected. This task will fail if it is a normal ``luigi.Task``, because RequiringTask will not run (missing dependency at runtime). """ class MyTask(RunOnceTask): pass class RequiringTask(luigi.Task): counter = 0 def requires(self): yield MyTask() def run(self): RequiringTask.counter += 1 self.run_locally(["RequiringTask"]) self.assertEqual(1, RequiringTask.counter) ================================================ FILE: test/import_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import os from helpers import unittest class ImportTest(unittest.TestCase): def import_test(self): """Test that all module can be imported""" luigidir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") packagedir = os.path.join(luigidir, "luigi") for root, subdirs, files in os.walk(packagedir): package = os.path.relpath(root, luigidir).replace("/", ".") if "__init__.py" in files: __import__(package) for f in files: if f.endswith(".py") and not f.startswith("_"): __import__(package + "." + f[:-3]) def import_luigi_test(self): """ Test that the top luigi package can be imported and contains the usual suspects. """ import luigi # These should exist (if not, this will cause AttributeErrors) expected = [ luigi.Event, luigi.Config, luigi.Task, luigi.ExternalTask, luigi.WrapperTask, luigi.Target, luigi.LocalTarget, luigi.namespace, luigi.RemoteScheduler, luigi.RPCError, luigi.run, luigi.build, luigi.Parameter, luigi.DateHourParameter, luigi.DateMinuteParameter, luigi.DateSecondParameter, luigi.DateParameter, luigi.MonthParameter, luigi.YearParameter, luigi.DateIntervalParameter, luigi.TimeDeltaParameter, luigi.IntParameter, luigi.FloatParameter, luigi.BoolParameter, ] self.assertGreater(len(expected), 0) ================================================ FILE: test/instance_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from helpers import unittest import luigi import luigi.date_interval import luigi.notifications import luigi.worker luigi.notifications.DEBUG = True class InstanceTest(unittest.TestCase): def test_simple(self): class DummyTask(luigi.Task): x = luigi.Parameter() dummy_1 = DummyTask(1) dummy_2 = DummyTask(2) dummy_1b = DummyTask(1) self.assertNotEqual(dummy_1, dummy_2) self.assertEqual(dummy_1, dummy_1b) def test_dep(self): test = self class A(luigi.Task): task_namespace = "instance" # to prevent task name conflict between tests def __init__(self): self.has_run = False super(A, self).__init__() def run(self): self.has_run = True class B(luigi.Task): x = luigi.Parameter() def requires(self): return A() # This will end up referring to the same object def run(self): test.assertTrue(self.requires().has_run) luigi.build([B(1), B(2)], local_scheduler=True) def test_external_instance_cache(self): class A(luigi.Task): task_namespace = "instance" # to prevent task name conflict between tests pass class OtherA(luigi.ExternalTask): task_family = "A" oa = OtherA() a = A() self.assertNotEqual(oa, a) def test_date(self): """Adding unit test because we had a problem with this""" class DummyTask(luigi.Task): x = luigi.DateIntervalParameter() dummy_1 = DummyTask(luigi.date_interval.Year(2012)) dummy_2 = DummyTask(luigi.date_interval.Year(2013)) dummy_1b = DummyTask(luigi.date_interval.Year(2012)) self.assertNotEqual(dummy_1, dummy_2) self.assertEqual(dummy_1, dummy_1b) def test_unhashable_type(self): # See #857 class DummyTask(luigi.Task): x = luigi.Parameter() dummy = DummyTask(x={}) # NOQA ================================================ FILE: test/instance_wrap_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import datetime import decimal from helpers import unittest import luigi import luigi.notifications from luigi.mock import MockTarget luigi.notifications.DEBUG = True class Report(luigi.Task): date = luigi.DateParameter() def run(self): f = self.output().open("w") f.write("10.0 USD\n") f.write("4.0 EUR\n") f.write("3.0 USD\n") f.close() def output(self): return MockTarget(self.date.strftime("/tmp/report-%Y-%m-%d")) class ReportReader(luigi.Task): date = luigi.DateParameter() def requires(self): return Report(self.date) def run(self): self.lines = list(self.input().open("r").readlines()) def get_line(self, line): amount, currency = self.lines[line].strip().split() return decimal.Decimal(amount), currency def complete(self): return False class CurrencyExchanger(luigi.Task): task = luigi.Parameter() currency_to = luigi.Parameter() exchange_rates = {("USD", "USD"): decimal.Decimal(1), ("EUR", "USD"): decimal.Decimal("1.25")} def requires(self): return self.task # Note that you still need to state this explicitly def get_line(self, line): amount, currency_from = self.task.get_line(line) return amount * self.exchange_rates[(currency_from, self.currency_to)], self.currency_to def complete(self): return False class InstanceWrapperTest(unittest.TestCase): """This test illustrates that tasks can have tasks as parameters This is a more complicated variant of factorial_test.py which is an example of tasks communicating directly with other tasks. In this case, a task takes another task as a parameter and wraps it. 
Also see wrap_test.py for an example of a task class wrapping another task class. Not the most useful pattern, but there's actually been a few cases where it was pretty handy to be able to do that. I'm adding it as a unit test to make sure that new code doesn't break the expected behavior. """ def test(self): d = datetime.date(2012, 1, 1) r = ReportReader(d) ex = CurrencyExchanger(r, "USD") luigi.build([ex], local_scheduler=True) self.assertEqual(ex.get_line(0), (decimal.Decimal("10.0"), "USD")) self.assertEqual(ex.get_line(1), (decimal.Decimal("5.0"), "USD")) ================================================ FILE: test/interface_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import sys from helpers import LuigiTestCase, with_config from mock import MagicMock, Mock, patch import luigi import luigi.date_interval import luigi.notifications from luigi.execution_summary import LuigiStatusCode from luigi.interface import _WorkerSchedulerFactory, core from luigi.worker import Worker luigi.notifications.DEBUG = True class InterfaceTest(LuigiTestCase): def setUp(self): self.worker = Worker() self.worker_scheduler_factory = _WorkerSchedulerFactory() self.worker_scheduler_factory.create_worker = Mock(return_value=self.worker) self.worker_scheduler_factory.create_local_scheduler = Mock() super(InterfaceTest, self).setUp() class NoOpTask(luigi.Task): param = luigi.Parameter() self.task_a = NoOpTask("a") self.task_b = NoOpTask("b") def _create_summary_dict_with(self, updates={}): summary_dict = { "completed": set(), "already_done": set(), "ever_failed": set(), "failed": set(), "scheduling_error": set(), "still_pending_ext": set(), "still_pending_not_ext": set(), "run_by_other_worker": set(), "upstream_failure": set(), "upstream_missing_dependency": set(), "upstream_run_by_other_worker": set(), "upstream_scheduling_error": set(), "not_run": set(), } summary_dict.update(updates) return summary_dict def _summary_dict_module_path(): return "luigi.execution_summary._summary_dict" def test_interface_run_positive_path(self): self.worker.add = Mock(side_effect=[True, True]) self.worker.run = Mock(return_value=True) self.assertTrue(self._run_interface()) def test_interface_run_positive_path_with_detailed_summary_enabled(self): self.worker.add = Mock(side_effect=[True, True]) self.worker.run = Mock(return_value=True) self.assertTrue(self._run_interface(detailed_summary=True).scheduling_succeeded) def test_interface_run_with_add_failure(self): self.worker.add = Mock(side_effect=[True, False]) self.worker.run = Mock(return_value=True) self.assertFalse(self._run_interface()) def test_interface_run_with_add_failure_with_detailed_summary_enabled(self): 
self.worker.add = Mock(side_effect=[True, False]) self.worker.run = Mock(return_value=True) self.assertFalse(self._run_interface(detailed_summary=True).scheduling_succeeded) def test_interface_run_with_run_failure(self): self.worker.add = Mock(side_effect=[True, True]) self.worker.run = Mock(return_value=False) self.assertFalse(self._run_interface()) def test_interface_run_with_run_failure_with_detailed_summary_enabled(self): self.worker.add = Mock(side_effect=[True, True]) self.worker.run = Mock(return_value=False) self.assertFalse(self._run_interface(detailed_summary=True).scheduling_succeeded) @patch(_summary_dict_module_path()) def test_that_status_is_success(self, fake_summary_dict): # Nothing in failed tasks so, should succeed fake_summary_dict.return_value = self._create_summary_dict_with() luigi_run_result = self._run_interface(detailed_summary=True) self.assertEqual(luigi_run_result.status, LuigiStatusCode.SUCCESS) @patch(_summary_dict_module_path()) def test_that_status_is_success_with_retry(self, fake_summary_dict): # Nothing in failed tasks (only an entry in ever_failed) so, should succeed with retry fake_summary_dict.return_value = self._create_summary_dict_with({"ever_failed": [self.task_a]}) luigi_run_result = self._run_interface(detailed_summary=True) self.assertEqual(luigi_run_result.status, LuigiStatusCode.SUCCESS_WITH_RETRY) @patch(_summary_dict_module_path()) def test_that_status_is_failed_when_there_is_one_failed_task(self, fake_summary_dict): # Should fail because a task failed fake_summary_dict.return_value = self._create_summary_dict_with({"ever_failed": [self.task_a], "failed": [self.task_a]}) luigi_run_result = self._run_interface(detailed_summary=True) self.assertEqual(luigi_run_result.status, LuigiStatusCode.FAILED) @patch(_summary_dict_module_path()) def test_that_status_is_failed_with_scheduling_failure(self, fake_summary_dict): # Failed task and also a scheduling error fake_summary_dict.return_value = self._create_summary_dict_with( 
{"ever_failed": [self.task_a], "failed": [self.task_a], "scheduling_error": [self.task_b]} ) luigi_run_result = self._run_interface(detailed_summary=True) self.assertEqual(luigi_run_result.status, LuigiStatusCode.FAILED_AND_SCHEDULING_FAILED) @patch(_summary_dict_module_path()) def test_that_status_is_scheduling_failed_with_one_scheduling_error(self, fake_summary_dict): # Scheduling error for at least one task fake_summary_dict.return_value = self._create_summary_dict_with({"scheduling_error": [self.task_b]}) luigi_run_result = self._run_interface(detailed_summary=True) self.assertEqual(luigi_run_result.status, LuigiStatusCode.SCHEDULING_FAILED) @patch(_summary_dict_module_path()) def test_that_status_is_not_run_with_one_task_not_run(self, fake_summary_dict): # At least one of the tasks was not run fake_summary_dict.return_value = self._create_summary_dict_with({"not_run": [self.task_a]}) luigi_run_result = self._run_interface(detailed_summary=True) self.assertEqual(luigi_run_result.status, LuigiStatusCode.NOT_RUN) @patch(_summary_dict_module_path()) def test_that_status_is_missing_ext_with_one_task_with_missing_external_dependency(self, fake_summary_dict): # Missing external dependency for at least one task fake_summary_dict.return_value = self._create_summary_dict_with({"still_pending_ext": [self.task_a]}) luigi_run_result = self._run_interface(detailed_summary=True) self.assertEqual(luigi_run_result.status, LuigiStatusCode.MISSING_EXT) def test_stops_worker_on_add_exception(self): worker = MagicMock() self.worker_scheduler_factory.create_worker = Mock(return_value=worker) worker.add = Mock(side_effect=AttributeError) self.assertRaises(AttributeError, self._run_interface) self.assertTrue(worker.__exit__.called) def test_stops_worker_on_run_exception(self): worker = MagicMock() self.worker_scheduler_factory.create_worker = Mock(return_value=worker) worker.add = Mock(side_effect=[True, True]) worker.run = Mock(side_effect=AttributeError) 
self.assertRaises(AttributeError, self._run_interface) self.assertTrue(worker.__exit__.called) def test_just_run_main_task_cls(self): class MyTestTask(luigi.Task): pass class MyOtherTestTask(luigi.Task): my_param = luigi.Parameter() with patch.object(sys, "argv", ["my_module.py", "--no-lock", "--local-scheduler"]): luigi.run(main_task_cls=MyTestTask) with patch.object(sys, "argv", ["my_module.py", "--no-lock", "--my-param", "my_value", "--local-scheduler"]): luigi.run(main_task_cls=MyOtherTestTask) def _run_interface(self, **env_params): return luigi.interface.build([self.task_a, self.task_b], worker_scheduler_factory=self.worker_scheduler_factory, **env_params) class CoreConfigTest(LuigiTestCase): @with_config({}) def test_parallel_scheduling_processes_default(self): self.assertEqual(0, core().parallel_scheduling_processes) @with_config({"core": {"parallel-scheduling-processes": "1234"}}) def test_parallel_scheduling_processes(self): from luigi.interface import core self.assertEqual(1234, core().parallel_scheduling_processes) ================================================ FILE: test/list_parameter_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import json import mock import pytest from helpers import in_parse, unittest from jsonschema import Draft4Validator from jsonschema.exceptions import ValidationError import luigi class ListParameterTask(luigi.Task): param = luigi.ListParameter() class ListParameterTest(unittest.TestCase): _list = [1, "one", True] def test_parse(self): d = luigi.ListParameter().parse(json.dumps(ListParameterTest._list)) self.assertEqual(d, ListParameterTest._list) def test_serialize(self): d = luigi.ListParameter().serialize(ListParameterTest._list) self.assertEqual(d, '[1, "one", true]') def test_list_serialize_parse(self): a = luigi.ListParameter() b_list = [1, 2, 3] self.assertEqual(b_list, a.parse(a.serialize(b_list))) def test_parse_interface(self): in_parse(["ListParameterTask", "--param", '[1, "one", true]'], lambda task: self.assertEqual(task.param, tuple(ListParameterTest._list))) def test_serialize_task(self): t = ListParameterTask(ListParameterTest._list) self.assertEqual(str(t), 'ListParameterTask(param=[1, "one", true])') def test_parse_invalid_input(self): self.assertRaises(ValueError, lambda: luigi.ListParameter().parse('{"invalid"}')) def test_hash_normalize(self): self.assertRaises(TypeError, lambda: hash(luigi.ListParameter().parse('"NOT A LIST"'))) a = luigi.ListParameter().normalize([0]) b = luigi.ListParameter().normalize([0]) self.assertEqual(hash(a), hash(b)) def test_schema(self): a = luigi.ListParameter( schema={ "type": "array", "items": { "type": "number", "minimum": 0, "maximum": 10, }, "minItems": 1, } ) # Check that the default value is validated with pytest.raises(ValidationError, match=r"'INVALID_ATTRIBUTE' is not of type 'number'"): a.normalize(["INVALID_ATTRIBUTE"]) # Check that empty list is not valid with pytest.raises(ValidationError): a.normalize([]) # Check that valid lists work valid_list = [1, 2, 3] a.normalize(valid_list) # Check that invalid lists raise correct errors invalid_list_type = ["NOT AN INT"] invalid_list_value = [-999, 4] with 
pytest.raises(ValidationError, match="'NOT AN INT' is not of type 'number'"): a.normalize(invalid_list_type) with pytest.raises(ValidationError, match="-999 is less than the minimum of 0"): a.normalize(invalid_list_value) # Check that warnings are properly emitted with mock.patch("luigi.parameter._JSONSCHEMA_ENABLED", False): with pytest.warns( UserWarning, match=("The 'jsonschema' package is not installed so the parameter can not be validated even though a schema is given.") ): luigi.ListParameter(schema={"type": "array", "items": {"type": "number"}}) # Test with a custom validator validator = Draft4Validator( schema={ "type": "array", "items": { "type": "number", "minimum": 0, "maximum": 10, }, "minItems": 1, } ) c = luigi.DictParameter(schema=validator) c.normalize(valid_list) with pytest.raises( ValidationError, match=r"'INVALID_ATTRIBUTE' is not of type 'number'", ): c.normalize(["INVALID_ATTRIBUTE"]) # Test with frozen data frozen_data = luigi.freezing.recursively_freeze(valid_list) c.normalize(frozen_data) ================================================ FILE: test/local_target_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import bz2 import gzip import io import itertools import os import random import shutil import sys from errno import EEXIST, EXDEV import mock from helpers import unittest from target_test import FileSystemTargetTestMixin import luigi.format from luigi import LocalTarget from luigi.local_target import LocalFileSystem from luigi.target import FileAlreadyExists, MissingParentDirectory class LocalTargetTest(unittest.TestCase, FileSystemTargetTestMixin): PATH_PREFIX = "/tmp/test.txt" def setUp(self): self.path = self.PATH_PREFIX + "-" + str(self.id()) self.copy = self.PATH_PREFIX + "-copy-" + str(self.id()) if os.path.exists(self.path): os.remove(self.path) if os.path.exists(self.copy): os.remove(self.copy) def tearDown(self): if os.path.exists(self.path): os.remove(self.path) if os.path.exists(self.copy): os.remove(self.copy) def create_target(self, format=None): return LocalTarget(self.path, format=format) def assertCleanUp(self, tmp_path=""): self.assertFalse(os.path.exists(tmp_path)) def test_exists(self): t = self.create_target() p = t.open("w") self.assertEqual(t.exists(), os.path.exists(self.path)) p.close() self.assertEqual(t.exists(), os.path.exists(self.path)) @unittest.skipIf(tuple(sys.version_info) < (3, 4), "only for Python>=3.4") def test_pathlib(self): """Test work with pathlib.Path""" import pathlib path = pathlib.Path(self.path) self.assertFalse(path.exists()) target = LocalTarget(path) self.assertFalse(target.exists()) with path.open("w") as stream: stream.write("test me") self.assertTrue(target.exists()) def test_gzip_with_module(self): t = LocalTarget(self.path, luigi.format.Gzip) p = t.open("w") test_data = b"test" p.write(test_data) print(self.path) self.assertFalse(os.path.exists(self.path)) p.close() self.assertTrue(os.path.exists(self.path)) # Using gzip module as validation f = gzip.open(self.path, "r") self.assertTrue(test_data == f.read()) f.close() # Verifying our own gzip reader f = LocalTarget(self.path, luigi.format.Gzip).open("r") 
self.assertTrue(test_data == f.read()) f.close() def test_bzip2(self): t = LocalTarget(self.path, luigi.format.Bzip2) p = t.open("w") test_data = b"test" p.write(test_data) print(self.path) self.assertFalse(os.path.exists(self.path)) p.close() self.assertTrue(os.path.exists(self.path)) # Using bzip module as validation f = bz2.BZ2File(self.path, "r") self.assertTrue(test_data == f.read()) f.close() # Verifying our own bzip2 reader f = LocalTarget(self.path, luigi.format.Bzip2).open("r") self.assertTrue(test_data == f.read()) f.close() def test_copy(self): t = LocalTarget(self.path) f = t.open("w") test_data = "test" f.write(test_data) f.close() self.assertTrue(os.path.exists(self.path)) self.assertFalse(os.path.exists(self.copy)) t.copy(self.copy) self.assertTrue(os.path.exists(self.path)) self.assertTrue(os.path.exists(self.copy)) self.assertEqual(t.open("r").read(), LocalTarget(self.copy).open("r").read()) def test_move(self): t = LocalTarget(self.path) f = t.open("w") test_data = "test" f.write(test_data) f.close() self.assertTrue(os.path.exists(self.path)) self.assertFalse(os.path.exists(self.copy)) t.move(self.copy) self.assertFalse(os.path.exists(self.path)) self.assertTrue(os.path.exists(self.copy)) def test_move_across_filesystems(self): t = LocalTarget(self.path) with t.open("w") as f: f.write("test_data") def rename_across_filesystems(src, dst): err = OSError() err.errno = EXDEV raise err real_rename = os.replace def mockreplace(src, dst): if "-across-fs" in src: real_rename(src, dst) else: rename_across_filesystems(src, dst) copy = "%s-across-fs" % self.copy with mock.patch("os.replace", mockreplace): t.move(copy) self.assertFalse(os.path.exists(self.path)) self.assertTrue(os.path.exists(copy)) self.assertEqual("test_data", LocalTarget(copy).open("r").read()) def test_format_chain(self): UTF8WIN = luigi.format.TextFormat(encoding="utf8", newline="\r\n") t = LocalTarget(self.path, UTF8WIN >> luigi.format.Gzip) a = "我é\nçф" with t.open("w") as f: 
f.write(a) f = gzip.open(self.path, "rb") b = f.read() f.close() self.assertEqual(b"\xe6\x88\x91\xc3\xa9\r\n\xc3\xa7\xd1\x84", b) def test_format_chain_reverse(self): t = LocalTarget(self.path, luigi.format.UTF8 >> luigi.format.Gzip) f = gzip.open(self.path, "wb") f.write(b"\xe6\x88\x91\xc3\xa9\r\n\xc3\xa7\xd1\x84") f.close() with t.open("r") as f: b = f.read() self.assertEqual("我é\nçф", b) @mock.patch("os.linesep", "\r\n") def test_format_newline(self): t = LocalTarget(self.path, luigi.format.SysNewLine) with t.open("w") as f: f.write(b"a\rb\nc\r\nd") with t.open("r") as f: b = f.read() with open(self.path, "rb") as f: c = f.read() self.assertEqual(b"a\nb\nc\nd", b) self.assertEqual(b"a\r\nb\r\nc\r\nd", c) def theoretical_io_modes(self, rwax="rwax", bt=["", "b", "t"], plus=["", "+"]): p = itertools.product(rwax, plus, bt) return {"".join(c) for c in list(itertools.chain.from_iterable([itertools.permutations(m) for m in p]))} def valid_io_modes(self, *a, **kw): modes = set() t = LocalTarget(is_tmp=True) t.open("w").close() for mode in self.theoretical_io_modes(*a, **kw): try: io.FileIO(t.path, mode).close() except ValueError: pass except IOError as err: if err.errno == EEXIST: modes.add(mode) else: raise else: modes.add(mode) return modes def valid_write_io_modes_for_luigi(self): return self.valid_io_modes("w", plus=[""]) def valid_read_io_modes_for_luigi(self): return self.valid_io_modes("r", plus=[""]) def invalid_io_modes_for_luigi(self): return self.valid_io_modes().difference(self.valid_write_io_modes_for_luigi(), self.valid_read_io_modes_for_luigi()) def test_open_modes(self): t = LocalTarget(is_tmp=True) print("Valid write mode:", end=" ") for mode in self.valid_write_io_modes_for_luigi(): print(mode, end=" ") p = t.open(mode) p.close() print() print("Valid read mode:", end=" ") for mode in self.valid_read_io_modes_for_luigi(): print(mode, end=" ") p = t.open(mode) p.close() print() print("Invalid mode:", end=" ") for mode in 
self.invalid_io_modes_for_luigi(): print(mode, end=" ") self.assertRaises(Exception, t.open, mode) print() class LocalTargetCreateDirectoriesTest(LocalTargetTest): path = "/tmp/%s/xyz/test.txt" % random.randint(0, 999999999) copy = "/tmp/%s/xyz_2/copy.txt" % random.randint(0, 999999999) class LocalTargetRelativeTest(LocalTargetTest): # We had a bug that caused relative file paths to fail, adding test for it path = "test.txt" copy = "copy.txt" class TmpFileTest(unittest.TestCase): def test_tmp(self): t = LocalTarget(is_tmp=True) self.assertFalse(t.exists()) self.assertFalse(os.path.exists(t.path)) p = t.open("w") print("test", file=p) self.assertFalse(t.exists()) self.assertFalse(os.path.exists(t.path)) p.close() self.assertTrue(t.exists()) self.assertTrue(os.path.exists(t.path)) q = t.open("r") self.assertEqual(q.readline(), "test\n") q.close() path = t.path del t # should remove the underlying file self.assertFalse(os.path.exists(path)) class FileSystemTest(unittest.TestCase): path = "/tmp/luigi-test-dir" fs = LocalFileSystem() def setUp(self): if os.path.exists(self.path): shutil.rmtree(self.path) def tearDown(self): self.setUp() def test_copy(self): src = os.path.join(self.path, "src.txt") dest = os.path.join(self.path, "newdir", "dest.txt") LocalTarget(src).open("w").close() self.fs.copy(src, dest) self.assertTrue(os.path.exists(src)) self.assertTrue(os.path.exists(dest)) def test_mkdir(self): testpath = os.path.join(self.path, "foo/bar") self.assertRaises(MissingParentDirectory, self.fs.mkdir, testpath, parents=False) self.fs.mkdir(testpath) self.assertTrue(os.path.exists(testpath)) self.assertTrue(self.fs.isdir(testpath)) self.assertRaises(FileAlreadyExists, self.fs.mkdir, testpath, raise_if_exists=True) def test_exists(self): self.assertFalse(self.fs.exists(self.path)) os.mkdir(self.path) self.assertTrue(self.fs.exists(self.path)) self.assertTrue(self.fs.isdir(self.path)) def test_listdir(self): os.mkdir(self.path) with open(self.path + "/file", "w"): pass 
self.assertTrue([self.path + "/file"], list(self.fs.listdir(self.path + "/"))) def test_move_to_new_dir(self): # Regression test for a bug in LocalFileSystem.move src = os.path.join(self.path, "src.txt") dest = os.path.join(self.path, "newdir", "dest.txt") LocalTarget(src).open("w").close() self.fs.move(src, dest) self.assertTrue(os.path.exists(dest)) class DestructorTest(unittest.TestCase): def test_destructor(self): # LocalTarget might not be fully initialised if an exception is thrown in the constructor of LocalTarget or a # subclass. The destructor can't expect attributes to be initialised. t = LocalTarget(is_tmp=True) del t.is_tmp t.__del__() ================================================ FILE: test/lock_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import os import subprocess import tempfile import mock from helpers import unittest from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential import luigi import luigi.lock import luigi.notifications luigi.notifications.DEBUG = True class TestCmd(unittest.TestCase): def test_getpcmd(self): def _is_empty(cmd): return cmd == "" # for CI stability, add retring @retry(retry=retry_if_result(_is_empty), wait=wait_exponential(multiplier=0.2, min=0.1, max=3), stop=stop_after_attempt(3)) def _getpcmd(pid): return luigi.lock.getpcmd(pid) if os.name == "nt": command = ["ping", "1.1.1.1", "-w", "1000"] else: command = ["sleep", "1"] external_process = subprocess.Popen(command) result = _getpcmd(external_process.pid) self.assertTrue(result.strip() in ["sleep 1", "[sleep]", "ping 1.1.1.1 -w 1000"]) external_process.kill() class LockTest(unittest.TestCase): def setUp(self): self.pid_dir = tempfile.mkdtemp() self.pid, self.cmd, self.pid_file = luigi.lock.get_info(self.pid_dir) def tearDown(self): if os.path.exists(self.pid_file): os.remove(self.pid_file) os.rmdir(self.pid_dir) def test_get_info(self): def _is_empty(result): return result[1] == "" # cmd is empty # for CI stability, add retring @retry(retry=retry_if_result(_is_empty), wait=wait_exponential(multiplier=0.2, min=0.1, max=3), stop=stop_after_attempt(3)) def _get_info(pid_dir, pid): return luigi.lock.get_info(pid_dir, pid) try: p = subprocess.Popen(["yes", "à我ф"], stdout=subprocess.PIPE) pid, cmd, pid_file = _get_info(self.pid_dir, p.pid) finally: p.kill() self.assertEqual(cmd, "yes à我ф") def test_acquiring_free_lock(self): acquired = luigi.lock.acquire_for(self.pid_dir) self.assertTrue(acquired) def test_acquiring_taken_lock(self): with open(self.pid_file, "w") as f: f.write("%d\n" % (self.pid,)) acquired = luigi.lock.acquire_for(self.pid_dir) self.assertFalse(acquired) def test_acquiring_partially_taken_lock(self): with open(self.pid_file, "w") as f: f.write("%d\n" % (self.pid,)) acquired = 
luigi.lock.acquire_for(self.pid_dir, 2) self.assertTrue(acquired) s = os.stat(self.pid_file) self.assertEqual(s.st_mode & 0o700, 0o700) def test_acquiring_lock_from_missing_process(self): fake_pid = 99999 with open(self.pid_file, "w") as f: f.write("%d\n" % (fake_pid,)) acquired = luigi.lock.acquire_for(self.pid_dir) self.assertTrue(acquired) s = os.stat(self.pid_file) self.assertEqual(s.st_mode & 0o700, 0o700) @mock.patch("os.kill") def test_take_lock_with_kill(self, kill_fn): with open(self.pid_file, "w") as f: f.write("%d\n" % (self.pid,)) kill_signal = 77777 acquired = luigi.lock.acquire_for(self.pid_dir, kill_signal=kill_signal) self.assertTrue(acquired) kill_fn.assert_called_once_with(self.pid, kill_signal) @mock.patch("os.kill") @mock.patch("luigi.lock.getpcmd") def test_take_lock_has_only_one_extra_life(self, getpcmd, kill_fn): def side_effect(pid): if pid in [self.pid, self.pid + 1, self.pid + 2]: return self.cmd # We could return something else too, actually else: return "echo something_else" getpcmd.side_effect = side_effect with open(self.pid_file, "w") as f: f.write("{}\n{}\n".format(self.pid + 1, self.pid + 2)) kill_signal = 77777 acquired = luigi.lock.acquire_for(self.pid_dir, kill_signal=kill_signal) self.assertFalse(acquired) # So imagine +2 was runnig first, then +1 was run with --take-lock kill_fn.assert_any_call(self.pid + 1, kill_signal) kill_fn.assert_any_call(self.pid + 2, kill_signal) @mock.patch("luigi.lock.getpcmd") def test_cleans_old_pid_entries(self, getpcmd): assert self.pid > 10 # I've never seen so low pids so SAME_ENTRIES = {1, 2, 3, 4, 5, self.pid} ALL_ENTRIES = SAME_ENTRIES | {6, 7, 8, 9, 10} def side_effect(pid): if pid in SAME_ENTRIES: return self.cmd # We could return something else too, actually elif pid == 8: return None else: return "echo something_else" getpcmd.side_effect = side_effect with open(self.pid_file, "w") as f: f.writelines("{}\n".format(pid) for pid in ALL_ENTRIES) acquired = luigi.lock.acquire_for(self.pid_dir, 
num_available=100) self.assertTrue(acquired) with open(self.pid_file, "r") as f: self.assertEqual({int(pid_str.strip()) for pid_str in f}, SAME_ENTRIES) ================================================ FILE: test/metrics_test.py ================================================ import unittest import luigi.metrics as metrics from luigi.contrib.datadog_metric import DatadogMetricsCollector from luigi.contrib.prometheus_metric import PrometheusMetricsCollector class TestMetricsCollectors(unittest.TestCase): def test_default_value(self): collector = metrics.MetricsCollectors.default output = metrics.MetricsCollectors.get(collector) assert type(output) is metrics.NoMetricsCollector def test_datadog_value(self): collector = metrics.MetricsCollectors.datadog output = metrics.MetricsCollectors.get(collector) assert type(output) is DatadogMetricsCollector def test_prometheus_value(self): collector = metrics.MetricsCollectors.prometheus output = metrics.MetricsCollectors.get(collector) assert type(output) is PrometheusMetricsCollector def test_none_value(self): collector = metrics.MetricsCollectors.none output = metrics.MetricsCollectors.get(collector) assert type(output) is metrics.NoMetricsCollector def test_other_value(self): collector = "junk" with self.assertRaises(ValueError) as context: metrics.MetricsCollectors.get(collector) assert ("MetricsCollectors value ' junk ' isn't supported") in str(context.exception) ================================================ FILE: test/mock_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from helpers import unittest from luigi.format import Nop from luigi.mock import MockFileSystem, MockTarget class MockFileTest(unittest.TestCase): def test_1(self): t = MockTarget("test") p = t.open("w") print("test", file=p) p.close() q = t.open("r") self.assertEqual(list(q), ["test\n"]) q.close() def test_with(self): t = MockTarget("foo") with t.open("w") as b: b.write("bar") with t.open("r") as b: self.assertEqual(list(b), ["bar"]) def test_bytes(self): t = MockTarget("foo", format=Nop) with t.open("wb") as b: b.write(b"bar") with t.open("rb") as b: self.assertEqual(list(b), [b"bar"]) def test_default_mode_value(self): t = MockTarget("foo") with t.open("w") as b: b.write("bar") with t.open() as b: self.assertEqual(list(b), ["bar"]) def test_mode_none_error(self): t = MockTarget("foo") with self.assertRaises(TypeError): with t.open(None) as b: b.write("bar") # That should work in python2 because of the autocast # That should work in python3 because the default format is Text def test_unicode(self): t = MockTarget("foo") with t.open("w") as b: b.write("bar") with t.open("r") as b: self.assertEqual(b.read(), "bar") class MockFileSystemTest(unittest.TestCase): fs = MockFileSystem() def _touch(self, path): t = MockTarget(path) with t.open("w"): pass def setUp(self): self.fs.clear() self.path = "/tmp/foo" self.path2 = "/tmp/bar" self.path3 = "/tmp/foobar" self._touch(self.path) self._touch(self.path2) def test_copy(self): self.fs.copy(self.path, self.path3) self.assertTrue(self.fs.exists(self.path)) self.assertTrue(self.fs.exists(self.path3)) def test_exists(self): 
self.assertTrue(self.fs.exists(self.path)) def test_remove(self): self.fs.remove(self.path) self.assertFalse(self.fs.exists(self.path)) def test_remove_recursive(self): self.fs.remove("/tmp", recursive=True) self.assertFalse(self.fs.exists(self.path)) self.assertFalse(self.fs.exists(self.path2)) def test_rename(self): self.fs.rename(self.path, self.path3) self.assertFalse(self.fs.exists(self.path)) self.assertTrue(self.fs.exists(self.path3)) def test_listdir(self): self.assertEqual(sorted([self.path, self.path2]), sorted(self.fs.listdir("/tmp"))) ================================================ FILE: test/most_common_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from helpers import unittest from luigi.tools.range import most_common class MostCommonTest(unittest.TestCase): def setUp(self): self.runs = [([1], (1, 1)), ([1, 1], (1, 2)), ([1, 1, 2], (1, 2)), ([1, 1, 2, 2, 2], (2, 3))] def test_runs(self): for args, result in self.runs: actual = most_common(args) expected = result self.assertEqual(expected, actual) ================================================ FILE: test/mypy_test.py ================================================ import sys import tempfile import unittest from mypy import api def _run_mypy(test_code: str): with tempfile.NamedTemporaryFile(suffix=".py") as test_file: test_file.write(test_code.encode("utf-8")) test_file.flush() return api.run( [ "--no-incremental", "--cache-dir=/dev/null", "--show-traceback", "--config-file", "test/testconfig/pyproject.toml", test_file.name, ] ) class TestMyMypyPlugin(unittest.TestCase): def test_plugin_no_issue(self): if sys.version_info[:2] < (3, 8): return test_code = """ from datetime import date, datetime, timedelta from enum import Enum import luigi from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Type from uuid import UUID class MyEnum(Enum): A = 1 B = 2 C = 3 class UUIDParameter(luigi.Parameter[UUID]): def parse(self, s): return UUID(s) class OtherTask(luigi.Task): pass class MyTask(luigi.Task): bool_p: bool = luigi.BoolParameter() choice_int_p: int = luigi.parameter.ChoiceParameter(choices=[1, 2, 3]) choice_list_int_p: Tuple[int, ...] = luigi.parameter.ChoiceListParameter(choices=[1, 2, 3]) choice_list_str_p: Tuple[str, ...] = luigi.parameter.ChoiceListParameter(choices=["foo", "bar", "baz"]) choice_str_p: str = luigi.parameter.ChoiceParameter(choices=["foo", "bar", "baz"]) date_p: date = luigi.DateParameter() datetime_p: datetime = luigi.DateSecondParameter() dict_p: Dict[str, str] = luigi.DictParameter() enum_p: MyEnum = luigi.parameter.EnumParameter(enum=MyEnum) enums_p: Tuple[MyEnum, ...] 
= luigi.parameter.EnumListParameter(enum=MyEnum) int_p: int = luigi.IntParameter() list_float_p: Tuple[Any, ...] = luigi.ListParameter() numeric_p: float = luigi.NumericalParameter(var_type=float, min_value=-3.0, max_value=7.0) opt_p: Optional[str] = luigi.OptionalParameter() path_p: Path = luigi.PathParameter() str_p: str = luigi.Parameter() str_p_default: str = luigi.Parameter(default="baz") task_p: Type[luigi.Task] = luigi.TaskParameter() timedelta_p: timedelta = luigi.TimeDeltaParameter() tuple_int_p: Tuple[Any, ...] = luigi.TupleParameter() uuid_p: UUID = UUIDParameter() MyTask( bool_p=True, choice_int_p=3, choice_list_int_p=(2, 3), choice_list_str_p=("foo", "baz"), choice_str_p="foo", date_p=date.today(), datetime_p=datetime.now(), dict_p={"foo": "bar"}, enum_p=MyEnum.B, enums_p=(MyEnum.A, MyEnum.C), int_p=1, list_float_p=(0.1, 0.2), numeric_p=4.0, opt_p=None, path_p=Path("/tmp"), str_p='bar', task_p=OtherTask, timedelta_p=timedelta(hours=1), tuple_int_p=(1, 2), uuid_p=UUID("9b0591d7-a167-4978-bc6d-41f7d84a288c"), ) """ stdout, stderr, exitcode = _run_mypy(test_code) self.assertEqual( exitcode, 0, f"mypy plugin error occurred:\nstdout: {stdout}\nstderr: {stderr}", ) self.assertIn("Success: no issues found", stdout) def test_plugin_invalid_arg(self): if sys.version_info[:2] < (3, 8): return test_code = """ import luigi class MyTask(luigi.Task): foo: int = luigi.IntParameter() bar: str = luigi.Parameter() baz: str = luigi.Parameter(default=1) # invalid assignment to str with default value int # issue: # - foo is int # - unknown is unknown parameter # - baz is invalid assignment to str with default value int MyTask(foo='1', bar="bar", unknown="unknown") """ stdout, stderr, exitcode = _run_mypy(test_code) self.assertEqual( exitcode, 1, f"mypy plugin error occurred:\nstdout: {stdout}\nstderr: {stderr}", ) self.assertIn( 'error: Incompatible types in assignment (expression has type "int", variable has type "str") [assignment]', stdout, ) # check baz assignment 
self.assertIn( 'error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', stdout, ) # check foo argument self.assertIn( 'error: Unexpected keyword argument "unknown" for "MyTask" [call-arg]', stdout, ) # check unknown argument self.assertIn("Found 3 errors in 1 file (checked 1 source file)", stdout) def test_plugin_custom_parameter_subclass_without_default_arg(self): """Test for issue #3376: Custom Parameter subclass without 'default' in __init__""" if sys.version_info[:2] < (3, 8): return test_code = """ import luigi class CustomPathParameter(luigi.PathParameter): \"\"\"A PathParameter subclass that doesn't expose 'default' in its signature.\"\"\" def __init__(self, **kwargs): super().__init__(**kwargs) class MyTask(luigi.Task): path = CustomPathParameter() """ stdout, stderr, exitcode = _run_mypy(test_code) self.assertEqual( exitcode, 0, f"mypy plugin error occurred:\nstdout: {stdout}\nstderr: {stderr}", ) self.assertIn("Success: no issues found", stdout) def test_plugin_parameter_type_annotation(self): """Test that Parameter types can be used as type annotations. Users should be able to write: foo: luigi.IntParameter = luigi.IntParameter() bar: luigi.Parameter[str] = luigi.Parameter() """ if sys.version_info[:2] < (3, 8): return test_code = """ import luigi class MyTask(luigi.Task): foo: luigi.IntParameter = luigi.IntParameter() bar: luigi.StrParameter = luigi.StrParameter() MyTask(foo=1, bar='2') """ stdout, stderr, exitcode = _run_mypy(test_code) self.assertEqual( exitcode, 0, f"mypy plugin error occurred:\nstdout: {stdout}\nstderr: {stderr}", ) self.assertIn("Success: no issues found", stdout) def test_plugin_parameter_type_annotation_invalid_arg(self): """Test that Parameter type annotations catch type errors in __init__ args. MyTask(foo='1', bar='2') should error because foo expects int, not str. 
""" if sys.version_info[:2] < (3, 8): return test_code = """ import luigi class MyTask(luigi.Task): foo: luigi.IntParameter = luigi.IntParameter() bar: luigi.StrParameter = luigi.StrParameter() MyTask(foo='1', bar='2') """ stdout, stderr, exitcode = _run_mypy(test_code) self.assertEqual( exitcode, 1, f"Expected mypy error but got:\nstdout: {stdout}\nstderr: {stderr}", ) self.assertIn( 'error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int"', stdout, ) self.assertIn("Found 1 error in 1 file (checked 1 source file)", stdout) ================================================ FILE: test/notifications_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import socket import sys import mock from helpers import unittest, with_config import luigi from luigi import notifications from luigi.notifications import generate_email from luigi.scheduler import Scheduler from luigi.worker import Worker class TestEmail(unittest.TestCase): def testEmailNoPrefix(self): self.assertEqual("subject", notifications._prefix("subject")) @with_config({"email": {"prefix": "[prefix]"}}) def testEmailPrefix(self): self.assertEqual("[prefix] subject", notifications._prefix("subject")) class TestException(Exception): pass class TestTask(luigi.Task): foo = luigi.Parameter() bar = luigi.Parameter() class FailSchedulingTask(TestTask): def requires(self): raise TestException("Oops!") def run(self): pass def complete(self): return False class FailRunTask(TestTask): def run(self): raise TestException("Oops!") def complete(self): return False class ExceptionFormatTest(unittest.TestCase): def setUp(self): self.sch = Scheduler() def test_fail_run(self): task = FailRunTask(foo="foo", bar="bar") self._run_task(task) def test_fail_run_html(self): task = FailRunTask(foo="foo", bar="bar") self._run_task_html(task) def test_fail_schedule(self): task = FailSchedulingTask(foo="foo", bar="bar") self._run_task(task) def test_fail_schedule_html(self): task = FailSchedulingTask(foo="foo", bar="bar") self._run_task_html(task) @with_config({"email": {"receiver": "nowhere@example.com", "prefix": "[TEST] "}}) @mock.patch("luigi.notifications.send_error_email") def _run_task(self, task, mock_send): with Worker(scheduler=self.sch) as w: w.add(task) w.run() self.assertEqual(mock_send.call_count, 1) args, kwargs = mock_send.call_args self._check_subject(args[0], task) self._check_body(args[1], task, html=False) @with_config({"email": {"receiver": "nowhere@axample.com", "prefix": "[TEST] ", "format": "html"}}) @mock.patch("luigi.notifications.send_error_email") def _run_task_html(self, task, mock_send): with Worker(scheduler=self.sch) as w: w.add(task) w.run() 
self.assertEqual(mock_send.call_count, 1) args, kwargs = mock_send.call_args self._check_subject(args[0], task) self._check_body(args[1], task, html=True) def _check_subject(self, subject, task): self.assertIn(str(task), subject) def _check_body(self, body, task, html=False): if html: self.assertIn("name{}".format(task.task_family), body) self.assertIn('
    {}{}".format(param, value), body) else: self.assertIn("Name: {}\n".format(task.task_family), body) self.assertIn("Parameters:\n", body) self.assertIn("TestException: Oops!", body) for param, value in task.param_kwargs.items(): self.assertIn("{}: {}\n".format(param, value), body) @with_config({"email": {"receiver": "a@a.a"}}) def testEmailRecipients(self): self.assertCountEqual(notifications._email_recipients(), ["a@a.a"]) self.assertCountEqual(notifications._email_recipients("b@b.b"), ["a@a.a", "b@b.b"]) self.assertCountEqual(notifications._email_recipients(["b@b.b", "c@c.c"]), ["a@a.a", "b@b.b", "c@c.c"]) @with_config({"email": {}}, replace_sections=True) def testEmailRecipientsNoConfig(self): self.assertCountEqual(notifications._email_recipients(), []) self.assertCountEqual(notifications._email_recipients("a@a.a"), ["a@a.a"]) self.assertCountEqual(notifications._email_recipients(["a@a.a", "b@b.b"]), ["a@a.a", "b@b.b"]) def test_generate_unicode_email(self): generate_email( sender="test@example.com", subject="sübjéct", message="你好", recipients=["receiver@example.com"], image_png=None, ) class NotificationFixture: """ Defines API and message fixture. config, sender, subject, message, recipients, image_png """ sender = "luigi@unittest" subject = "Oops!" message = """A multiline message.""" recipients = ["noone@nowhere.no", "phantom@opera.fr"] image_png = None notification_args = [sender, subject, message, recipients, image_png] mocked_email_msg = """Content-Type: multipart/related; boundary="===============0998157881==" MIME-Version: 1.0 Subject: Oops! From: luigi@unittest To: noone@nowhere.no,phantom@opera.fr --===============0998157881== MIME-Version: 1.0 Content-Transfer-Encoding: 7bit Content-Type: text/plain; charset="utf-8" A multiline message. --===============0998157881==--""" class TestSMTPEmail(unittest.TestCase, NotificationFixture): """ Tests sending SMTP email. 
""" def setUp(self): sys.modules["smtplib"] = mock.MagicMock() import smtplib # noqa: F401 def tearDown(self): del sys.modules["smtplib"] @with_config( { "smtp": { "ssl": "False", "host": "my.smtp.local", "port": "999", "local_hostname": "ptms", "timeout": "1200", "username": "Robin", "password": "dooH", "no_tls": "False", } } ) def test_sends_smtp_email(self): """ Call notifications.send_email_smtp with fixture parameters with smtp_without_tls set to False and check that sendmail is properly called. """ smtp_kws = {"host": "my.smtp.local", "port": 999, "local_hostname": "ptms", "timeout": 1200} with mock.patch("smtplib.SMTP") as SMTP: with mock.patch("luigi.notifications.generate_email") as generate_email: generate_email.return_value.as_string.return_value = self.mocked_email_msg notifications.send_email_smtp(*self.notification_args) SMTP.assert_called_once_with(**smtp_kws) SMTP.return_value.login.assert_called_once_with("Robin", "dooH") SMTP.return_value.starttls.assert_called_once_with() SMTP.return_value.sendmail.assert_called_once_with(self.sender, self.recipients, self.mocked_email_msg) @with_config( { "smtp": { "ssl": "False", "host": "my.smtp.local", "port": "999", "local_hostname": "ptms", "timeout": "1200", "username": "Robin", "password": "dooH", "no_tls": "True", } } ) def test_sends_smtp_email_without_tls(self): """ Call notifications.send_email_smtp with fixture parameters with no_tls set to True and check that sendmail is properly called without also calling starttls. 
""" smtp_kws = {"host": "my.smtp.local", "port": 999, "local_hostname": "ptms", "timeout": 1200} with mock.patch("smtplib.SMTP") as SMTP: with mock.patch("luigi.notifications.generate_email") as generate_email: generate_email.return_value.as_string.return_value = self.mocked_email_msg notifications.send_email_smtp(*self.notification_args) SMTP.assert_called_once_with(**smtp_kws) self.assertEqual(SMTP.return_value.starttls.called, False) SMTP.return_value.login.assert_called_once_with("Robin", "dooH") SMTP.return_value.sendmail.assert_called_once_with(self.sender, self.recipients, self.mocked_email_msg) @with_config( { "smtp": { "ssl": "False", "host": "my.smtp.local", "port": "999", "local_hostname": "ptms", "timeout": "1200", "username": "Robin", "password": "dooH", "no_tls": "True", } } ) def test_sends_smtp_email_exceptions(self): """ Call notifications.send_email_smtp when it cannot connect to smtp server (socket.error) starttls. """ smtp_kws = {"host": "my.smtp.local", "port": 999, "local_hostname": "ptms", "timeout": 1200} with mock.patch("smtplib.SMTP") as SMTP: with mock.patch("luigi.notifications.generate_email") as generate_email: SMTP.side_effect = socket.error() generate_email.return_value.as_string.return_value = self.mocked_email_msg try: notifications.send_email_smtp(*self.notification_args) except socket.error: self.fail("send_email_smtp() raised expection unexpectedly") SMTP.assert_called_once_with(**smtp_kws) self.assertEqual(notifications.generate_email.called, False) self.assertEqual(SMTP.sendemail.called, False) class TestSendgridEmail(unittest.TestCase, NotificationFixture): """ Tests sending Sendgrid email. 
""" def setUp(self): sys.modules["sendgrid"] = mock.MagicMock() import sendgrid # noqa: F401 def tearDown(self): del sys.modules["sendgrid"] @with_config({"sendgrid": {"apikey": "456abcdef123"}}) def test_sends_sendgrid_email(self): """ Call notifications.send_email_sendgrid with fixture parameters and check that SendGridAPIClient is properly called. """ with mock.patch("sendgrid.SendGridAPIClient") as SendGridAPIClient: notifications.send_email_sendgrid(*self.notification_args) SendGridAPIClient.assert_called_once_with("456abcdef123") self.assertTrue(SendGridAPIClient.return_value.send.called) class TestSESEmail(unittest.TestCase, NotificationFixture): """ Tests sending email through AWS SES. """ def setUp(self): sys.modules["boto3"] = mock.MagicMock() import boto3 # noqa: F401 def tearDown(self): del sys.modules["boto3"] @with_config({}) def test_sends_ses_email(self): """ Call notifications.send_email_ses with fixture parameters and check that boto is properly called. """ with mock.patch("boto3.client") as boto_client: with mock.patch("luigi.notifications.generate_email") as generate_email: generate_email.return_value.as_string.return_value = self.mocked_email_msg notifications.send_email_ses(*self.notification_args) SES = boto_client.return_value SES.send_raw_email.assert_called_once_with(Source=self.sender, Destinations=self.recipients, RawMessage={"Data": self.mocked_email_msg}) class TestSNSNotification(unittest.TestCase, NotificationFixture): """ Tests sending email through AWS SNS. """ def setUp(self): sys.modules["boto3"] = mock.MagicMock() import boto3 # noqa: F401 def tearDown(self): del sys.modules["boto3"] @with_config({}) def test_sends_sns_email(self): """ Call notifications.send_email_sns with fixture parameters and check that boto3 is properly called. 
""" with mock.patch("boto3.resource") as res: notifications.send_email_sns(*self.notification_args) SNS = res.return_value SNS.Topic.assert_called_once_with(self.recipients[0]) SNS.Topic.return_value.publish.assert_called_once_with(Subject=self.subject, Message=self.message) @with_config({}) def test_sns_subject_is_shortened(self): """ Call notifications.send_email_sns with too long Subject (more than 100 chars) and check that it is cut to length of 100 chars. """ long_subject = ( "Luigi: SanityCheck(regexPattern=aligned-source\\|data-not-older\\|source-chunks-compl,mailFailure=False, mongodb=mongodb://localhost/stats) FAILED" ) with mock.patch("boto3.resource") as res: notifications.send_email_sns(self.sender, long_subject, self.message, self.recipients, self.image_png) SNS = res.return_value SNS.Topic.assert_called_once_with(self.recipients[0]) called_subj = SNS.Topic.return_value.publish.call_args[1]["Subject"] self.assertTrue(len(called_subj) <= 100, "Subject can be max 100 chars long! Found {}.".format(len(called_subj))) class TestNotificationDispatcher(unittest.TestCase, NotificationFixture): """ Test dispatching of notifications on configuration values. """ def check_dispatcher(self, target): """ Call notifications.send_email and test that the proper function was called. 
""" expected_args = self.notification_args with mock.patch("luigi.notifications.{}".format(target)) as sender: notifications.send_email(self.subject, self.message, self.sender, self.recipients, image_png=self.image_png) self.assertTrue(sender.called) call_args = sender.call_args[0] self.assertEqual(tuple(expected_args), call_args) @with_config({"email": {"force_send": "True", "method": "smtp"}}) def test_smtp(self): return self.check_dispatcher("send_email_smtp") @with_config({"email": {"force_send": "True", "method": "ses"}}) def test_ses(self): return self.check_dispatcher("send_email_ses") @with_config({"email": {"force_send": "True", "method": "sendgrid"}}) def test_sendgrid(self): return self.check_dispatcher("send_email_sendgrid") @with_config({"email": {"force_send": "True", "method": "sns"}}) def test_sns(self): return self.check_dispatcher("send_email_sns") ================================================ FILE: test/numerical_parameter_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from operator import le, lt from helpers import unittest import luigi class NumericalParameterTest(unittest.TestCase): def test_int_min_value_inclusive(self): d = luigi.NumericalParameter(var_type=int, min_value=-3, max_value=7, left_op=le, right_op=lt) self.assertEqual(-3, d.parse(-3)) def test_float_min_value_inclusive(self): d = luigi.NumericalParameter(var_type=float, min_value=-3, max_value=7, left_op=le, right_op=lt) self.assertEqual(-3.0, d.parse(-3)) def test_int_min_value_exclusive(self): d = luigi.NumericalParameter(var_type=int, min_value=-3, max_value=7, left_op=lt, right_op=lt) self.assertRaises(ValueError, lambda: d.parse(-3)) def test_float_min_value_exclusive(self): d = luigi.NumericalParameter(var_type=int, min_value=-3, max_value=7, left_op=lt, right_op=lt) self.assertRaises(ValueError, lambda: d.parse(-3)) def test_int_max_value_inclusive(self): d = luigi.NumericalParameter(var_type=int, min_value=-3, max_value=7, left_op=le, right_op=le) self.assertEqual(7, d.parse(7)) def test_float_max_value_inclusive(self): d = luigi.NumericalParameter(var_type=float, min_value=-3, max_value=7, left_op=le, right_op=le) self.assertEqual(7, d.parse(7)) def test_int_max_value_exclusive(self): d = luigi.NumericalParameter(var_type=int, min_value=-3, max_value=7, left_op=le, right_op=lt) self.assertRaises(ValueError, lambda: d.parse(7)) def test_float_max_value_exclusive(self): d = luigi.NumericalParameter(var_type=float, min_value=-3, max_value=7, left_op=le, right_op=lt) self.assertRaises(ValueError, lambda: d.parse(7)) def test_defaults_start_range(self): d = luigi.NumericalParameter(var_type=int, min_value=-3, max_value=7) self.assertEqual(-3, d.parse(-3)) def test_endpoint_default_exclusive(self): d = luigi.NumericalParameter(var_type=int, min_value=-3, max_value=7) self.assertRaises(ValueError, lambda: d.parse(7)) def test_var_type_parameter_exception(self): self.assertRaises(luigi.parameter.ParameterException, lambda: 
luigi.NumericalParameter(min_value=-3, max_value=7)) def test_min_value_parameter_exception(self): self.assertRaises(luigi.parameter.ParameterException, lambda: luigi.NumericalParameter(var_type=int, max_value=7)) def test_max_value_parameter_exception(self): self.assertRaises(luigi.parameter.ParameterException, lambda: luigi.NumericalParameter(var_type=int, min_value=-3)) def test_hash_int(self): class Foo(luigi.Task): args = luigi.parameter.NumericalParameter(var_type=int, min_value=-3, max_value=7) p = luigi.parameter.NumericalParameter(var_type=int, min_value=-3, max_value=7) self.assertEqual(hash(Foo(args=-3).args), hash(p.parse("-3"))) def test_hash_float(self): class Foo(luigi.Task): args = luigi.parameter.NumericalParameter(var_type=float, min_value=-3, max_value=7) p = luigi.parameter.NumericalParameter(var_type=float, min_value=-3, max_value=7) self.assertEqual(hash(Foo(args=-3.0).args), hash(p.parse("-3.0"))) def test_int_serialize_parse(self): a = luigi.parameter.NumericalParameter(var_type=int, min_value=-3, max_value=7) b = -3 self.assertEqual(b, a.parse(a.serialize(b))) def test_float_serialize_parse(self): a = luigi.parameter.NumericalParameter(var_type=float, min_value=-3, max_value=7) b = -3.0 self.assertEqual(b, a.parse(a.serialize(b))) ================================================ FILE: test/optional_parameter_test.py ================================================ import warnings import mock from helpers import LuigiTestCase, with_config import luigi class OptionalParameterTest(LuigiTestCase): def actual_test(self, cls, default, expected_value, expected_type, bad_data, **kwargs): class TestConfig(luigi.Config): param = cls(default=default, **kwargs) empty_param = cls(default=default, **kwargs) def run(self): assert self.param == expected_value assert self.empty_param is None # Test parsing empty string (should be None) self.assertIsNone(cls(**kwargs).parse("")) # Test next_in_enumeration always returns None for summary 
self.assertIsNone(TestConfig.param.next_in_enumeration(expected_value)) self.assertIsNone(TestConfig.param.next_in_enumeration(None)) # Test that warning is raised only with bad type with mock.patch("luigi.parameter.warnings") as warnings: TestConfig() warnings.warn.assert_not_called() if cls != luigi.OptionalChoiceParameter: with mock.patch("luigi.parameter.warnings") as warnings: TestConfig(param=None) warnings.warn.assert_not_called() with mock.patch("luigi.parameter.warnings") as warnings: TestConfig(param=bad_data) if cls == luigi.OptionalBoolParameter: warnings.warn.assert_not_called() else: warnings.warn.assert_called_with( '{} "param" with value "{}" is not of type "{}" or None.'.format(cls.__name__, bad_data, expected_type), luigi.parameter.OptionalParameterTypeWarning, ) # Test with value from config self.assertTrue(luigi.build([TestConfig()], local_scheduler=True)) @with_config({"TestConfig": {"param": "expected value", "empty_param": ""}}) def test_optional_parameter(self): self.actual_test(luigi.OptionalParameter, None, "expected value", "str", 0) self.actual_test(luigi.OptionalParameter, "default value", "expected value", "str", 0) @with_config({"TestConfig": {"param": "10", "empty_param": ""}}) def test_optional_int_parameter(self): self.actual_test(luigi.OptionalIntParameter, None, 10, "int", "bad data") self.actual_test(luigi.OptionalIntParameter, 1, 10, "int", "bad data") @with_config({"TestConfig": {"param": "true", "empty_param": ""}}) def test_optional_bool_parameter(self): self.actual_test(luigi.OptionalBoolParameter, None, True, "bool", "bad data") self.actual_test(luigi.OptionalBoolParameter, False, True, "bool", "bad data") @with_config({"TestConfig": {"param": "10.5", "empty_param": ""}}) def test_optional_float_parameter(self): self.actual_test(luigi.OptionalFloatParameter, None, 10.5, "float", "bad data") self.actual_test(luigi.OptionalFloatParameter, 1.5, 10.5, "float", "bad data") @with_config({"TestConfig": {"param": '{"a": 10}', 
"empty_param": ""}}) def test_optional_dict_parameter(self): self.actual_test(luigi.OptionalDictParameter, None, {"a": 10}, "FrozenOrderedDict", "bad data") self.actual_test(luigi.OptionalDictParameter, {"a": 1}, {"a": 10}, "FrozenOrderedDict", "bad data") @with_config({"TestConfig": {"param": "[10.5]", "empty_param": ""}}) def test_optional_list_parameter(self): self.actual_test(luigi.OptionalListParameter, None, (10.5,), "tuple", "bad data") self.actual_test(luigi.OptionalListParameter, (1.5,), (10.5,), "tuple", "bad data") @with_config({"TestConfig": {"param": "[10.5]", "empty_param": ""}}) def test_optional_tuple_parameter(self): self.actual_test(luigi.OptionalTupleParameter, None, (10.5,), "tuple", "bad data") self.actual_test(luigi.OptionalTupleParameter, (1.5,), (10.5,), "tuple", "bad data") @with_config({"TestConfig": {"param": "10.5", "empty_param": ""}}) def test_optional_numerical_parameter_float(self): self.actual_test(luigi.OptionalNumericalParameter, None, 10.5, "float", "bad data", var_type=float, min_value=0, max_value=100) self.actual_test(luigi.OptionalNumericalParameter, 1.5, 10.5, "float", "bad data", var_type=float, min_value=0, max_value=100) @with_config({"TestConfig": {"param": "10", "empty_param": ""}}) def test_optional_numerical_parameter_int(self): self.actual_test(luigi.OptionalNumericalParameter, None, 10, "int", "bad data", var_type=int, min_value=0, max_value=100) self.actual_test(luigi.OptionalNumericalParameter, 1, 10, "int", "bad data", var_type=int, min_value=0, max_value=100) @with_config({"TestConfig": {"param": "expected value", "empty_param": ""}}) def test_optional_choice_parameter(self): choices = ["default value", "expected value"] self.actual_test(luigi.OptionalChoiceParameter, None, "expected value", "str", "bad data", choices=choices) self.actual_test(luigi.OptionalChoiceParameter, "default value", "expected value", "str", "bad data", choices=choices) @with_config({"TestConfig": {"param": "1", "empty_param": ""}}) def 
test_optional_choice_parameter_int(self): choices = [0, 1, 2] self.actual_test(luigi.OptionalChoiceParameter, None, 1, "int", "bad data", var_type=int, choices=choices) self.actual_test(luigi.OptionalChoiceParameter, "default value", 1, "int", "bad data", var_type=int, choices=choices) def test_warning(self): class TestOptionalFloatParameterSingleType(luigi.parameter.OptionalParameter, luigi.FloatParameter): expected_type = float class TestOptionalFloatParameterMultiTypes(luigi.parameter.OptionalParameter, luigi.FloatParameter): expected_type = (int, float) class TestConfig(luigi.Config): param_single = TestOptionalFloatParameterSingleType() param_multi = TestOptionalFloatParameterMultiTypes() with warnings.catch_warnings(record=True) as record: TestConfig(param_single=0.0, param_multi=1.0) assert len(record) == 0 with warnings.catch_warnings(record=True) as record: warnings.filterwarnings( action="ignore", category=Warning, ) warnings.simplefilter( action="always", category=luigi.parameter.OptionalParameterTypeWarning, ) assert luigi.build([TestConfig(param_single="0", param_multi="1")], local_scheduler=True) assert len(record) == 2 assert issubclass(record[0].category, luigi.parameter.OptionalParameterTypeWarning) assert issubclass(record[1].category, luigi.parameter.OptionalParameterTypeWarning) assert str(record[0].message) == ('TestOptionalFloatParameterSingleType "param_single" with value "0" is not of type "float" or None.') assert str(record[1].message) == ('TestOptionalFloatParameterMultiTypes "param_multi" with value "1" is not of any type in ["int", "float"] or None.') ================================================ FILE: test/other_module.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import luigi class OtherModuleTask(luigi.Task): p = luigi.Parameter() def output(self): return luigi.LocalTarget(self.p) def run(self): with self.output().open("w") as f: f.write("Done!") ================================================ FILE: test/parameter_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import datetime import enum from datetime import timedelta import mock import pytest from helpers import LuigiTestCase, RunOnceTask, in_parse, parsing, with_config from worker_test import email_patch import luigi import luigi.date_interval import luigi.interface import luigi.notifications from luigi.mock import MockTarget from luigi.parameter import ParameterException luigi.notifications.DEBUG = True class A(luigi.Task): _visible_in_registry = False # test fixture: invisible to registry to prevent name conflicts p = luigi.IntParameter() class WithDefault(luigi.Task): x = luigi.Parameter(default="xyz") class WithDefaultTrue(luigi.Task): x = luigi.BoolParameter(default=True) class WithDefaultFalse(luigi.Task): x = luigi.BoolParameter(default=False) class Foo(luigi.Task): _visible_in_registry = False # test fixture: invisible to registry to prevent name conflicts bar = luigi.Parameter() p2 = luigi.IntParameter() not_a_param = "lol" class Baz(luigi.Task): bool = luigi.BoolParameter() bool_true = luigi.BoolParameter(default=True) bool_explicit = luigi.BoolParameter(parsing=luigi.BoolParameter.EXPLICIT_PARSING) def run(self): Baz._val = self.bool Baz._val_true = self.bool_true Baz._val_explicit = self.bool_explicit class ListFoo(luigi.Task): my_list = luigi.ListParameter() def run(self): ListFoo._val = self.my_list class TupleFoo(luigi.Task): my_tuple = luigi.TupleParameter() def run(self): TupleFoo._val = self.my_tuple class ForgotParam(luigi.Task): param = luigi.Parameter() def run(self): pass class ForgotParamDep(luigi.Task): def requires(self): return ForgotParam() def run(self): pass class BananaDep(luigi.Task): x = luigi.Parameter() y = luigi.Parameter(default="def") def output(self): return MockTarget("banana-dep-%s-%s" % (self.x, self.y)) def run(self): self.output().open("w").close() class Banana(luigi.Task): x = luigi.Parameter() y = luigi.Parameter() style = luigi.Parameter(default=None) def requires(self): if self.style is None: return BananaDep() # will 
fail elif self.style == "x-arg": return BananaDep(self.x) elif self.style == "y-kwarg": return BananaDep(y=self.y) elif self.style == "x-arg-y-arg": return BananaDep(self.x, self.y) else: raise Exception("unknown style") def output(self): return MockTarget("banana-%s-%s" % (self.x, self.y)) def run(self): self.output().open("w").close() class MyConfig(luigi.Config): mc_p = luigi.IntParameter() mc_q = luigi.IntParameter(default=73) class MyConfigWithoutSection(luigi.Config): use_cmdline_section = False mc_r = luigi.IntParameter() mc_s = luigi.IntParameter(default=99) class NoopTask(luigi.Task): pass class MyEnum(enum.Enum): A = 1 C = 3 def _value(parameter): """ A hackish way to get the "value" of a parameter. Previously Parameter exposed ``param_obj._value``. This is replacement for that so I don't need to rewrite all test cases. """ class DummyLuigiTask(luigi.Task): param = parameter return DummyLuigiTask().param class ParameterTest(LuigiTestCase): def test_default_param(self): self.assertEqual(WithDefault().x, "xyz") def test_missing_param(self): def create_a(): return A() self.assertRaises(luigi.parameter.MissingParameterException, create_a) def test_unknown_param(self): def create_a(): return A(p=5, q=4) self.assertRaises(luigi.parameter.UnknownParameterException, create_a) def test_unknown_param_2(self): def create_a(): return A(1, 2, 3) self.assertRaises(luigi.parameter.UnknownParameterException, create_a) def test_duplicated_param(self): def create_a(): return A(5, p=7) self.assertRaises(luigi.parameter.DuplicateParameterException, create_a) def test_parameter_registration(self): self.assertEqual(len(Foo.get_params()), 2) def test_task_creation(self): f = Foo("barval", p2=5) self.assertEqual(len(f.get_params()), 2) self.assertEqual(f.bar, "barval") self.assertEqual(f.p2, 5) self.assertEqual(f.not_a_param, "lol") def test_bool_parsing(self): self.run_locally(["Baz"]) self.assertFalse(Baz._val) self.assertTrue(Baz._val_true) self.assertFalse(Baz._val_explicit) 
self.run_locally(["Baz", "--bool", "--bool-true"]) self.assertTrue(Baz._val) self.assertTrue(Baz._val_true) self.run_locally(["Baz", "--bool-explicit", "true"]) self.assertTrue(Baz._val_explicit) self.run_locally(["Baz", "--bool-explicit", "false"]) self.assertFalse(Baz._val_explicit) def test_bool_default(self): self.assertTrue(WithDefaultTrue().x) self.assertFalse(WithDefaultFalse().x) def test_bool_coerce(self): self.assertTrue(WithDefaultTrue(x="true").x) self.assertFalse(WithDefaultTrue(x="false").x) def test_bool_no_coerce_none(self): self.assertIsNone(WithDefaultTrue(x=None).x) def test_forgot_param(self): self.assertRaises( luigi.parameter.MissingParameterException, self.run_locally, ["ForgotParam"], ) @email_patch def test_forgot_param_in_dep(self, emails): # A programmatic missing parameter will cause an error email to be sent self.run_locally(["ForgotParamDep"]) self.assertNotEqual(emails, []) def test_default_param_cmdline(self): self.assertEqual(WithDefault().x, "xyz") def test_default_param_cmdline_2(self): self.assertEqual(WithDefault().x, "xyz") def test_insignificant_parameter(self): class InsignificantParameterTask(luigi.Task): foo = luigi.Parameter(significant=False, default="foo_default") bar = luigi.Parameter() t1 = InsignificantParameterTask(foo="x", bar="y") self.assertEqual(str(t1), "InsignificantParameterTask(bar=y)") t2 = InsignificantParameterTask("u", "z") self.assertEqual(t2.foo, "u") self.assertEqual(t2.bar, "z") self.assertEqual(str(t2), "InsignificantParameterTask(bar=z)") def test_local_significant_param(self): """Obviously, if anything should be positional, so should local significant parameters""" class MyTask(luigi.Task): # This could typically be "--label-company=disney" x = luigi.Parameter(significant=True) MyTask("arg") self.assertRaises(luigi.parameter.MissingParameterException, lambda: MyTask()) def test_local_insignificant_param(self): """Ensure we have the same behavior as in before a78338c""" class MyTask(luigi.Task): # 
This could typically be "--num-threads=True" x = luigi.Parameter(significant=False) MyTask("arg") self.assertRaises(luigi.parameter.MissingParameterException, lambda: MyTask()) def test_nonpositional_param(self): """Ensure we have the same behavior as in before a78338c""" class MyTask(luigi.Task): # This could typically be "--num-threads=10" x = luigi.Parameter(significant=False, positional=False) MyTask(x="arg") self.assertRaises(luigi.parameter.UnknownParameterException, lambda: MyTask("arg")) def test_enum_param_valid(self): p = luigi.parameter.EnumParameter(enum=MyEnum) self.assertEqual(MyEnum.A, p.parse("A")) def test_enum_param_invalid(self): p = luigi.parameter.EnumParameter(enum=MyEnum) self.assertRaises(ValueError, lambda: p.parse("B")) def test_enum_param_missing(self): self.assertRaises(ParameterException, lambda: luigi.parameter.EnumParameter()) def test_enum_list_param_valid(self): p = luigi.parameter.EnumListParameter(enum=MyEnum) self.assertEqual((), p.parse("")) self.assertEqual((MyEnum.A,), p.parse("A")) self.assertEqual((MyEnum.A, MyEnum.C), p.parse("A,C")) def test_enum_list_param_invalid(self): p = luigi.parameter.EnumListParameter(enum=MyEnum) self.assertRaises(ValueError, lambda: p.parse("A,B")) def test_enum_list_param_missing(self): self.assertRaises(ParameterException, lambda: luigi.parameter.EnumListParameter()) def test_choice_list_param_valid(self): p = luigi.parameter.ChoiceListParameter(choices=["1", "2", "3"]) self.assertEqual((), p.parse("")) self.assertEqual(("1",), p.parse("1")) self.assertEqual(("1", "3"), p.parse("1,3")) def test_choice_list_param_invalid(self): p = luigi.parameter.ChoiceListParameter(choices=["1", "2", "3"]) self.assertRaises(ValueError, lambda: p.parse("1,4")) def test_invalid_choice_type(self): self.assertRaises( AssertionError, lambda: luigi.ChoiceListParameter(var_type=int, choices=[1, 2, "3"]), ) def test_choice_list_param_missing(self): self.assertRaises(ParameterException, lambda: 
luigi.parameter.ChoiceListParameter()) def test_tuple_serialize_parse(self): a = luigi.TupleParameter() b_tuple = ((1, 2), (3, 4)) self.assertEqual(b_tuple, a.parse(a.serialize(b_tuple))) def test_parse_list_without_batch_method(self): param = luigi.Parameter() for xs in [], ["x"], ["x", "y"]: self.assertRaises(NotImplementedError, param._parse_list, xs) def test_parse_empty_list_raises_value_error(self): for batch_method in (max, min, tuple, ",".join): param = luigi.Parameter(batch_method=batch_method) self.assertRaises(ValueError, param._parse_list, []) def test_parse_int_list_max(self): param = luigi.IntParameter(batch_method=max) self.assertEqual(17, param._parse_list(["7", "17", "5"])) def test_parse_string_list_max(self): param = luigi.Parameter(batch_method=max) self.assertEqual("7", param._parse_list(["7", "17", "5"])) def test_parse_list_as_tuple(self): param = luigi.IntParameter(batch_method=tuple) self.assertEqual((7, 17, 5), param._parse_list(["7", "17", "5"])) @mock.patch("luigi.parameter.warnings") def test_warn_on_default_none(self, warnings): class TestConfig(luigi.Config): param = luigi.Parameter(default=None) TestConfig() warnings.warn.assert_called_once_with('Parameter "param" with value "None" is not of type string.') @mock.patch("luigi.parameter.warnings") def test_no_warn_on_string(self, warnings): class TestConfig(luigi.Config): param = luigi.Parameter(default=None) TestConfig(param="str") warnings.warn.assert_not_called() def test_no_warn_on_none_in_optional(self): class TestConfig(luigi.Config): param = luigi.OptionalParameter(default=None) with mock.patch("luigi.parameter.warnings") as warnings: TestConfig() warnings.warn.assert_not_called() with mock.patch("luigi.parameter.warnings") as warnings: TestConfig(param=None) warnings.warn.assert_not_called() with mock.patch("luigi.parameter.warnings") as warnings: TestConfig(param="") warnings.warn.assert_not_called() @mock.patch("luigi.parameter.warnings") def 
test_no_warn_on_string_in_optional(self, warnings): class TestConfig(luigi.Config): param = luigi.OptionalParameter(default=None) TestConfig(param="value") warnings.warn.assert_not_called() @mock.patch("luigi.parameter.warnings") def test_warn_on_bad_type_in_optional(self, warnings): class TestConfig(luigi.Config): param = luigi.OptionalParameter() TestConfig(param=1) warnings.warn.assert_called_once_with( 'OptionalParameter "param" with value "1" is not of type "str" or None.', luigi.parameter.OptionalParameterTypeWarning ) def test_optional_parameter_parse_none(self): self.assertIsNone(luigi.OptionalParameter().parse("")) def test_optional_parameter_parse_string(self): self.assertEqual("test", luigi.OptionalParameter().parse("test")) def test_optional_parameter_serialize_none(self): self.assertEqual("", luigi.OptionalParameter().serialize(None)) def test_optional_parameter_serialize_string(self): self.assertEqual("test", luigi.OptionalParameter().serialize("test")) class TestParametersHashability(LuigiTestCase): def test_date(self): class Foo(luigi.Task): args = luigi.parameter.DateParameter() p = luigi.parameter.DateParameter() self.assertEqual(hash(Foo(args=datetime.date(2000, 1, 1)).args), hash(p.parse("2000-1-1"))) def test_dateminute(self): class Foo(luigi.Task): args = luigi.parameter.DateMinuteParameter() p = luigi.parameter.DateMinuteParameter() self.assertEqual(hash(Foo(args=datetime.datetime(2000, 1, 1, 12, 0)).args), hash(p.parse("2000-1-1T1200"))) def test_dateinterval(self): class Foo(luigi.Task): args = luigi.parameter.DateIntervalParameter() p = luigi.parameter.DateIntervalParameter() di = luigi.date_interval.Custom(datetime.date(2000, 1, 1), datetime.date(2000, 2, 12)) self.assertEqual(hash(Foo(args=di).args), hash(p.parse("2000-01-01-2000-02-12"))) def test_timedelta(self): class Foo(luigi.Task): args = luigi.parameter.TimeDeltaParameter() p = luigi.parameter.TimeDeltaParameter() self.assertEqual(hash(Foo(args=datetime.timedelta(days=2, hours=3, 
minutes=2)).args), hash(p.parse("P2DT3H2M"))) def test_boolean(self): class Foo(luigi.Task): args = luigi.parameter.BoolParameter() p = luigi.parameter.BoolParameter() self.assertEqual(hash(Foo(args=True).args), hash(p.parse("true"))) self.assertEqual(hash(Foo(args=False).args), hash(p.parse("false"))) def test_int(self): class Foo(luigi.Task): args = luigi.parameter.IntParameter() p = luigi.parameter.IntParameter() self.assertEqual(hash(Foo(args=1).args), hash(p.parse("1"))) def test_float(self): class Foo(luigi.Task): args = luigi.parameter.FloatParameter() p = luigi.parameter.FloatParameter() self.assertEqual(hash(Foo(args=1.0).args), hash(p.parse("1"))) def test_enum(self): class Foo(luigi.Task): args = luigi.parameter.EnumParameter(enum=MyEnum) p = luigi.parameter.EnumParameter(enum=MyEnum) self.assertEqual(hash(Foo(args=MyEnum.A).args), hash(p.parse("A"))) def test_enum_list(self): class Foo(luigi.Task): args = luigi.parameter.EnumListParameter(enum=MyEnum) p = luigi.parameter.EnumListParameter(enum=MyEnum) self.assertEqual(hash(Foo(args=(MyEnum.A, MyEnum.C)).args), hash(p.parse("A,C"))) class FooWithDefault(luigi.Task): args = luigi.parameter.EnumListParameter(enum=MyEnum, default=[MyEnum.C]) self.assertEqual(FooWithDefault().args, p.parse("C")) def test_choice_list(self): class Foo(luigi.Task): args = luigi.ChoiceListParameter(var_type=str, choices=["1", "2", "3"]) p = luigi.ChoiceListParameter(var_type=str, choices=["3", "2", "1"]) self.assertEqual(hash(Foo(args=("3",)).args), hash(p.parse("3"))) def test_dict(self): class Foo(luigi.Task): args = luigi.parameter.DictParameter() p = luigi.parameter.DictParameter() self.assertEqual(hash(Foo(args=dict(foo=1, bar="hello")).args), hash(p.parse('{"foo":1,"bar":"hello"}'))) def test_list(self): class Foo(luigi.Task): args = luigi.parameter.ListParameter() p = luigi.parameter.ListParameter() self.assertEqual(hash(Foo(args=[1, "hello"]).args), hash(p.normalize(p.parse('[1,"hello"]')))) def 
test_list_param_with_default_none_in_dynamic_req_task(self): class TaskWithDefaultNoneParameter(RunOnceTask): args = luigi.parameter.ListParameter(default=None) class DynamicTaskCallsDefaultNoneParameter(RunOnceTask): def run(self): yield [TaskWithDefaultNoneParameter()] self.comp = True self.assertTrue(self.run_locally(["DynamicTaskCallsDefaultNoneParameter"])) def test_list_dict(self): class Foo(luigi.Task): args = luigi.parameter.ListParameter() p = luigi.parameter.ListParameter() self.assertEqual(hash(Foo(args=[{"foo": "bar"}, {"doge": "wow"}]).args), hash(p.normalize(p.parse('[{"foo": "bar"}, {"doge": "wow"}]')))) def test_list_nested(self): class Foo(luigi.Task): args = luigi.parameter.ListParameter() p = luigi.parameter.ListParameter() self.assertEqual(hash(Foo(args=[["foo", "bar"], ["doge", "wow"]]).args), hash(p.normalize(p.parse('[["foo", "bar"], ["doge", "wow"]]')))) def test_tuple(self): class Foo(luigi.Task): args = luigi.parameter.TupleParameter() p = luigi.parameter.TupleParameter() self.assertEqual(hash(Foo(args=(1, "hello")).args), hash(p.parse('(1,"hello")'))) def test_tuple_dict(self): class Foo(luigi.Task): args = luigi.parameter.TupleParameter() p = luigi.parameter.TupleParameter() self.assertEqual(hash(Foo(args=({"foo": "bar"}, {"doge": "wow"})).args), hash(p.normalize(p.parse('({"foo": "bar"}, {"doge": "wow"})')))) def test_tuple_nested(self): class Foo(luigi.Task): args = luigi.parameter.TupleParameter() p = luigi.parameter.TupleParameter() self.assertEqual(hash(Foo(args=(("foo", "bar"), ("doge", "wow"))).args), hash(p.normalize(p.parse('(("foo", "bar"), ("doge", "wow"))')))) def test_tuple_string_with_json(self): class Foo(luigi.Task): args = luigi.parameter.TupleParameter() p = luigi.parameter.TupleParameter() self.assertEqual(hash(Foo(args=("foo", "bar")).args), hash(p.normalize(p.parse('["foo", "bar"]')))) def test_tuple_invalid_string(self): param = luigi.TupleParameter() self.assertRaises(ValueError, lambda: param.parse('("abcd")')) 
# NOTE(review): the three methods below take ``self`` and call ``self.assert*``,
# so they belong to a test class whose ``class`` line precedes this point.
    def test_tuple_invalid_string_in_tuple(self):
        # '(("abcd"))' is only a parenthesized string, not a tuple; parse() is
        # expected to reject it with ValueError.
        param = luigi.TupleParameter()
        self.assertRaises(ValueError, lambda: param.parse('(("abcd"))'))

    def test_parse_invalid_format(self):
        # Unbalanced parentheses: the input is not valid literal syntax.
        param = luigi.TupleParameter()
        self.assertRaises(SyntaxError, lambda: param.parse("((1,2),(3,4"))

    def test_task(self):
        class Bar(luigi.Task):
            pass

        class Foo(luigi.Task):
            args = luigi.parameter.TaskParameter()

        p = luigi.parameter.TaskParameter()
        # Parsing the family name "Bar" must yield the same class as passing
        # Bar directly (compared via hash).
        self.assertEqual(hash(Foo(args=Bar).args), hash(p.parse("Bar")))


# Command-line tests for Banana / BananaDep (task classes defined elsewhere in
# this file, not visible here). Each test asserts exactly which MockTarget
# paths were written, i.e. which x/y values each task ended up with.
class TestNewStyleGlobalParameters(LuigiTestCase):
    def setUp(self):
        super(TestNewStyleGlobalParameters, self).setUp()
        # Start every test from an empty in-memory filesystem.
        MockTarget.fs.clear()

    def expect_keys(self, expected):
        """Assert that the MockTarget filesystem holds exactly the ``expected`` paths."""
        self.assertEqual(set(MockTarget.fs.get_all_data().keys()), set(expected))

    def test_x_arg(self):
        self.run_locally(["Banana", "--x", "foo", "--y", "bar", "--style", "x-arg"])
        self.expect_keys(["banana-foo-bar", "banana-dep-foo-def"])

    def test_x_arg_override(self):
        self.run_locally(["Banana", "--x", "foo", "--y", "bar", "--style", "x-arg", "--BananaDep-y", "xyz"])
        self.expect_keys(["banana-foo-bar", "banana-dep-foo-xyz"])

    def test_x_arg_override_stupid(self):
        # --BananaDep-x has no visible effect here: the dep still gets x=foo.
        self.run_locally(["Banana", "--x", "foo", "--y", "bar", "--style", "x-arg", "--BananaDep-x", "blabla"])
        self.expect_keys(["banana-foo-bar", "banana-dep-foo-def"])

    def test_x_arg_y_arg(self):
        self.run_locally(["Banana", "--x", "foo", "--y", "bar", "--style", "x-arg-y-arg"])
        self.expect_keys(["banana-foo-bar", "banana-dep-foo-bar"])

    def test_x_arg_y_arg_override(self):
        self.run_locally(["Banana", "--x", "foo", "--y", "bar", "--style", "x-arg-y-arg", "--BananaDep-y", "xyz"])
        self.expect_keys(["banana-foo-bar", "banana-dep-foo-bar"])

    def test_x_arg_y_arg_override_all(self):
        self.run_locally(["Banana", "--x", "foo", "--y", "bar", "--style", "x-arg-y-arg", "--BananaDep-y", "xyz", "--BananaDep-x", "blabla"])
        self.expect_keys(["banana-foo-bar", "banana-dep-foo-bar"])

    def test_y_arg_override(self):
        self.run_locally(["Banana", "--x", "foo", "--y", "bar", "--style", "y-kwarg", "--BananaDep-x", "xyz"])
        self.expect_keys(["banana-foo-bar", "banana-dep-xyz-bar"])

    def test_y_arg_override_both(self):
        self.run_locally(["Banana", "--x", "foo", "--y", "bar", "--style", "y-kwarg", "--BananaDep-x", "xyz", "--BananaDep-y", "blah"])
        self.expect_keys(["banana-foo-bar", "banana-dep-xyz-bar"])

    def test_y_arg_override_banana(self):
        self.run_locally(["Banana", "--y", "bar", "--style", "y-kwarg", "--BananaDep-x", "xyz", "--Banana-x", "baz"])
        self.expect_keys(["banana-baz-bar", "banana-dep-xyz-bar"])


# Config classes set from the command line. MyConfig / MyConfigWithoutSection
# are defined elsewhere in this file. MyConfigWithoutSection presumably sets
# use_cmdline_section = False (compare CatsWithoutSection in test_misc_1),
# which is why its params are set with bare --mc-r / --mc-s flags.
# NOTE(review): 73 (mc_q) and 99 (mc_s) are presumably the declared defaults.
class TestRemoveGlobalParameters(LuigiTestCase):
    def run_and_check(self, args):
        """Run ``args`` locally, assert the run succeeded, and return its result."""
        run_exit_status = self.run_locally(args)
        self.assertTrue(run_exit_status)
        return run_exit_status

    @parsing(["--MyConfig-mc-p", "99", "--mc-r", "55", "NoopTask"])
    def test_use_config_class_1(self):
        self.assertEqual(MyConfig().mc_p, 99)
        self.assertEqual(MyConfig().mc_q, 73)
        self.assertEqual(MyConfigWithoutSection().mc_r, 55)
        self.assertEqual(MyConfigWithoutSection().mc_s, 99)

    @parsing(["NoopTask", "--MyConfig-mc-p", "99", "--mc-r", "55"])
    def test_use_config_class_2(self):
        # Same as _1 but with the flags placed after the task name.
        self.assertEqual(MyConfig().mc_p, 99)
        self.assertEqual(MyConfig().mc_q, 73)
        self.assertEqual(MyConfigWithoutSection().mc_r, 55)
        self.assertEqual(MyConfigWithoutSection().mc_s, 99)

    @parsing(["--MyConfig-mc-p", "99", "--mc-r", "55", "NoopTask", "--mc-s", "123", "--MyConfig-mc-q", "42"])
    def test_use_config_class_more_args(self):
        self.assertEqual(MyConfig().mc_p, 99)
        self.assertEqual(MyConfig().mc_q, 42)
        self.assertEqual(MyConfigWithoutSection().mc_r, 55)
        self.assertEqual(MyConfigWithoutSection().mc_s, 123)

    @with_config({"MyConfig": {"mc_p": "666", "mc_q": "777"}})
    @parsing(["--mc-r", "555", "NoopTask"])
    def test_use_config_class_with_configuration(self):
        self.assertEqual(MyConfig().mc_p, 666)
        self.assertEqual(MyConfig().mc_q, 777)
        self.assertEqual(MyConfigWithoutSection().mc_r, 555)
        self.assertEqual(MyConfigWithoutSection().mc_s, 99)

    @with_config({"MyConfigWithoutSection": {"mc_r": "999", "mc_s": "888"}})
    @parsing(["NoopTask", "--MyConfig-mc-p", "222", "--mc-r", "555"])
    def test_use_config_class_with_configuration_2(self):
        # The command-line --mc-r (555) wins over the config value (999).
        self.assertEqual(MyConfig().mc_p, 222)
        self.assertEqual(MyConfig().mc_q, 73)
        self.assertEqual(MyConfigWithoutSection().mc_r, 555)
        self.assertEqual(MyConfigWithoutSection().mc_s, 888)

    @with_config({"MyConfig": {"mc_p": "555", "mc-p": "666", "mc-q": "777"}})
    def test_configuration_style(self):
        # The underscore key wins over the dashed key when both are present;
        # a dashed key on its own (mc-q) is still read.
        self.assertEqual(MyConfig().mc_p, 555)
        self.assertEqual(MyConfig().mc_q, 777)

    def test_misc_1(self):
        class Dogs(luigi.Config):
            n_dogs = luigi.IntParameter()

        class CatsWithoutSection(luigi.Config):
            use_cmdline_section = False
            n_cats = luigi.IntParameter()

        # Flag position relative to the task name must not matter.
        with luigi.cmdline_parser.CmdlineParser.global_instance(["--n-cats", "123", "--Dogs-n-dogs", "456", "WithDefault"], allow_override=True):
            self.assertEqual(Dogs().n_dogs, 456)
            self.assertEqual(CatsWithoutSection().n_cats, 123)

        with luigi.cmdline_parser.CmdlineParser.global_instance(["WithDefault", "--n-cats", "321", "--Dogs-n-dogs", "654"], allow_override=True):
            self.assertEqual(Dogs().n_dogs, 654)
            self.assertEqual(CatsWithoutSection().n_cats, 321)

    def test_global_significant_param_warning(self):
        """We don't want any kind of global param to be positional"""
        with self.assertWarnsRegex(DeprecationWarning, "is_global support is removed. Assuming positional=False"):

            class MyTask(luigi.Task):
                # This could typically be called "--test-dry-run"
                x_g1 = luigi.Parameter(default="y", is_global=True, significant=True)

        self.assertRaises(luigi.parameter.UnknownParameterException, lambda: MyTask("arg"))

    def test_global_insignificant_param_warning(self):
        """We don't want any kind of global param to be positional"""
        with self.assertWarnsRegex(DeprecationWarning, "is_global support is removed. Assuming positional=False"):

            class MyTask(luigi.Task):
                # This could typically be "--yarn-pool=development"
                x_g2 = luigi.Parameter(default="y", is_global=True, significant=False)

        self.assertRaises(luigi.parameter.UnknownParameterException, lambda: MyTask("arg"))


# Parameters whose value comes from configuration. The lookup is either an
# explicit config_path=dict(section=..., name=...) or a [TaskName] section keyed
# by the parameter name. ``_value`` is a helper defined elsewhere in this file.
class TestParamWithDefaultFromConfig(LuigiTestCase):
    def testNoSection(self):
        self.assertRaises(ParameterException, lambda: _value(luigi.Parameter(config_path=dict(section="foo", name="bar"))))

    @with_config({"foo": {}})
    def testNoValue(self):
        self.assertRaises(ParameterException, lambda: _value(luigi.Parameter(config_path=dict(section="foo", name="bar"))))

    @with_config({"foo": {"bar": "baz"}})
    def testDefault(self):
        class LocalA(luigi.Task):
            p = luigi.Parameter(config_path=dict(section="foo", name="bar"))

        self.assertEqual("baz", LocalA().p)
        # An explicit argument beats the config value.
        self.assertEqual("boo", LocalA(p="boo").p)

    @with_config({"foo": {"bar": "2001-02-03T04"}})
    def testDateHour(self):
        p = luigi.DateHourParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(datetime.datetime(2001, 2, 3, 4, 0, 0), _value(p))

    @with_config({"foo": {"bar": "2001-02-03T05"}})
    def testDateHourWithInterval(self):
        # With interval=2, hour 05 is floored to 04.
        p = luigi.DateHourParameter(config_path=dict(section="foo", name="bar"), interval=2)
        self.assertEqual(datetime.datetime(2001, 2, 3, 4, 0, 0), _value(p))

    @with_config({"foo": {"bar": "2001-02-03T0430"}})
    def testDateMinute(self):
        p = luigi.DateMinuteParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(datetime.datetime(2001, 2, 3, 4, 30, 0), _value(p))

    @with_config({"foo": {"bar": "2001-02-03T0431"}})
    def testDateWithMinuteInterval(self):
        p = luigi.DateMinuteParameter(config_path=dict(section="foo", name="bar"), interval=2)
        self.assertEqual(datetime.datetime(2001, 2, 3, 4, 30, 0), _value(p))

    @with_config({"foo": {"bar": "2001-02-03T04H30"}})
    def testDateMinuteDeprecated(self):
        p = luigi.DateMinuteParameter(config_path=dict(section="foo", name="bar"))
        with self.assertWarnsRegex(DeprecationWarning, 'Using "H" between hours and minutes is deprecated, omit it instead.'):
            self.assertEqual(datetime.datetime(2001, 2, 3, 4, 30, 0), _value(p))

    @with_config({"foo": {"bar": "2001-02-03T040506"}})
    def testDateSecond(self):
        p = luigi.DateSecondParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(datetime.datetime(2001, 2, 3, 4, 5, 6), _value(p))

    @with_config({"foo": {"bar": "2001-02-03T040507"}})
    def testDateSecondWithInterval(self):
        p = luigi.DateSecondParameter(config_path=dict(section="foo", name="bar"), interval=2)
        self.assertEqual(datetime.datetime(2001, 2, 3, 4, 5, 6), _value(p))

    @with_config({"foo": {"bar": "2001-02-03"}})
    def testDate(self):
        p = luigi.DateParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(datetime.date(2001, 2, 3), _value(p))

    @with_config({"foo": {"bar": "2001-02-03"}})
    def testDateWithInterval(self):
        # 3-day buckets starting 2001-02-01: the 3rd falls in the first bucket.
        p = luigi.DateParameter(config_path=dict(section="foo", name="bar"), interval=3, start=datetime.date(2001, 2, 1))
        self.assertEqual(datetime.date(2001, 2, 1), _value(p))

    @with_config({"foo": {"bar": "2015-07"}})
    def testMonthParameter(self):
        p = luigi.MonthParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(datetime.date(2015, 7, 1), _value(p))

    @with_config({"foo": {"bar": "2015-07"}})
    def testMonthWithIntervalParameter(self):
        # 13-month buckets from 2014-01: 2015-07 floors to 2015-02.
        p = luigi.MonthParameter(config_path=dict(section="foo", name="bar"), interval=13, start=datetime.date(2014, 1, 1))
        self.assertEqual(datetime.date(2015, 2, 1), _value(p))

    @with_config({"foo": {"bar": "2015"}})
    def testYearParameter(self):
        p = luigi.YearParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(datetime.date(2015, 1, 1), _value(p))

    @with_config({"foo": {"bar": "2015"}})
    def testYearWithIntervalParameter(self):
        p = luigi.YearParameter(config_path=dict(section="foo", name="bar"), start=datetime.date(2011, 1, 1), interval=5)
        self.assertEqual(datetime.date(2011, 1, 1), _value(p))

    @with_config({"foo": {"bar": "123"}})
    def testInt(self):
        p = luigi.IntParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(123, _value(p))

    @with_config({"foo": {"bar": "true"}})
    def testBool(self):
        p = luigi.BoolParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(True, _value(p))

    @with_config({"foo": {"bar": "false"}})
    def testBoolConfigOutranksDefault(self):
        p = luigi.BoolParameter(default=True, config_path=dict(section="foo", name="bar"))
        self.assertEqual(False, _value(p))

    @with_config({"foo": {"bar": "2001-02-03-2001-02-28"}})
    def testDateInterval(self):
        p = luigi.DateIntervalParameter(config_path=dict(section="foo", name="bar"))
        expected = luigi.date_interval.Custom.parse("2001-02-03-2001-02-28")
        self.assertEqual(expected, _value(p))

    @with_config({"foo": {"bar": "0 seconds"}})
    def testTimeDeltaNoSeconds(self):
        p = luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(timedelta(seconds=0), _value(p))

    @with_config({"foo": {"bar": "0 d"}})
    def testTimeDeltaNoDays(self):
        p = luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(timedelta(days=0), _value(p))

    @with_config({"foo": {"bar": "1 day"}})
    def testTimeDelta(self):
        p = luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(timedelta(days=1), _value(p))

    @with_config({"foo": {"bar": "2 seconds"}})
    def testTimeDeltaPlural(self):
        p = luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(timedelta(seconds=2), _value(p))

    @with_config({"foo": {"bar": "3w 4h 5m"}})
    def testTimeDeltaMultiple(self):
        p = luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(timedelta(weeks=3, hours=4, minutes=5), _value(p))

    @with_config({"foo": {"bar": "P4DT12H30M5S"}})
    def testTimeDelta8601(self):
        p = luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(timedelta(days=4, hours=12, minutes=30, seconds=5), _value(p))

    @with_config({"foo": {"bar": "P5D"}})
    def testTimeDelta8601NoTimeComponent(self):
        p = luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(timedelta(days=5), _value(p))

    @with_config({"foo": {"bar": "P5W"}})
    def testTimeDelta8601Weeks(self):
        p = luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(timedelta(weeks=5), _value(p))

    # ParameterException is patched so the test can assert the exact message it
    # was constructed with.
    @mock.patch("luigi.parameter.ParameterException")
    @with_config({"foo": {"bar": "P3Y6M4DT12H30M5S"}})
    def testTimeDelta8601YearMonthNotSupported(self, exc):
        def f():
            return _value(luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar")))

        self.assertRaises(ValueError, f)  # ISO 8601 durations with years or months are not supported
        exc.assert_called_once_with("Invalid time delta - could not parse P3Y6M4DT12H30M5S")

    @with_config({"foo": {"bar": "PT6M"}})
    def testTimeDelta8601MAfterT(self):
        # "M" after "T" means minutes.
        p = luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(timedelta(minutes=6), _value(p))

    # "M" before "T" means months, which are rejected (see the test above).
    @mock.patch("luigi.parameter.ParameterException")
    @with_config({"foo": {"bar": "P6M"}})
    def testTimeDelta8601MBeforeT(self, exc):
        def f():
            return _value(luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar")))

        self.assertRaises(ValueError, f)  # ISO 8601 durations with months are not supported
        exc.assert_called_once_with("Invalid time delta - could not parse P6M")

    @with_config({"foo": {"bar": "12.34"}})
    def testTimeDeltaFloat(self):
        # A bare number is interpreted as seconds.
        p = luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(timedelta(seconds=12.34), _value(p))

    @with_config({"foo": {"bar": "56789"}})
    def testTimeDeltaInt(self):
        p = luigi.TimeDeltaParameter(config_path=dict(section="foo", name="bar"))
        self.assertEqual(timedelta(seconds=56789), _value(p))

    def testHasDefaultNoSection(self):
        self.assertRaises(luigi.parameter.MissingParameterException, lambda: _value(luigi.Parameter(config_path=dict(section="foo", name="bar"))))

    @with_config({"foo": {}})
    def testHasDefaultNoValue(self):
        self.assertRaises(luigi.parameter.MissingParameterException, lambda: _value(luigi.Parameter(config_path=dict(section="foo", name="bar"))))

    @with_config({"foo": {"bar": "baz"}})
    def testHasDefaultWithBoth(self):
        self.assertTrue(_value(luigi.Parameter(config_path=dict(section="foo", name="bar"))))

    @with_config({"foo": {"bar": "baz"}})
    def testWithDefault(self):
        p = luigi.Parameter(config_path=dict(section="foo", name="bar"), default="blah")
        self.assertEqual("baz", _value(p))  # config overrides default

    def testWithDefaultAndMissing(self):
        # No config present: fall back to the declared default.
        p = luigi.Parameter(config_path=dict(section="foo", name="bar"), default="blah")
        self.assertEqual("blah", _value(p))

    @with_config({"LocalA": {"p": "p_default"}})
    def testDefaultFromTaskName(self):
        class LocalA(luigi.Task):
            p = luigi.Parameter()

        self.assertEqual("p_default", LocalA().p)
        self.assertEqual("boo", LocalA(p="boo").p)

    @with_config({"LocalA": {"p": "999"}})
    def testDefaultFromTaskNameInt(self):
        class LocalA(luigi.Task):
            p = luigi.IntParameter()

        self.assertEqual(999, LocalA().p)
        self.assertEqual(777, LocalA(p=777).p)

    @with_config({"LocalA": {"p": "p_default"}, "foo": {"bar": "baz"}})
    def testDefaultFromConfigWithTaskNameToo(self):
        # The [LocalA] section wins over the explicit config_path section.
        class LocalA(luigi.Task):
            p = luigi.Parameter(config_path=dict(section="foo", name="bar"))

        self.assertEqual("p_default", LocalA().p)
        self.assertEqual("boo", LocalA(p="boo").p)

    @with_config({"LocalA": {"p": "p_default_2"}})
    def testDefaultFromTaskNameWithDefault(self):
        class LocalA(luigi.Task):
            p = luigi.Parameter(default="banana")

        self.assertEqual("p_default_2", LocalA().p)
        self.assertEqual("boo_2", LocalA(p="boo_2").p)

    @with_config({"MyClass": {"p_wohoo": "p_default_3"}})
    def testWithLongParameterName(self):
        class MyClass(luigi.Task):
            p_wohoo = luigi.Parameter(default="banana")

        self.assertEqual("p_default_3", MyClass().p_wohoo)
        self.assertEqual("boo_2", MyClass(p_wohoo="boo_2").p_wohoo)

    @with_config({"RangeDaily": {"days_back": "123"}})
    def testSettingOtherMember(self):
        class LocalA(luigi.Task):
            pass

        self.assertEqual(123, luigi.tools.range.RangeDaily(of=LocalA).days_back)
        self.assertEqual(70, luigi.tools.range.RangeDaily(of=LocalA, days_back=70).days_back)

    @with_config({"MyClass": {"p_not_global": "123"}})
    def testCommandLineWithDefault(self):
        """
        Verify that we also read from the config when we build tasks from the
        command line parsers.
        """

        class MyClass(luigi.Task):
            p_not_global = luigi.Parameter(default="banana")

            def complete(self):
                import sys

                # Debug aid: dump the active configuration.
                luigi.configuration.get_config().write(sys.stdout)
                if self.p_not_global != "123":
                    raise ValueError("The parameter didn't get set!!")
                return True

            def run(self):
                pass

        self.assertTrue(self.run_locally(["MyClass"]))
        # Both the local and the --MyClass- prefixed flag override the config.
        self.assertFalse(self.run_locally(["MyClass", "--p-not-global", "124"]))
        self.assertFalse(self.run_locally(["MyClass", "--MyClass-p-not-global", "124"]))

    @with_config({"MyClass2": {"p_not_global_no_default": "123"}})
    def testCommandLineNoDefault(self):
        """
        Verify that we also read from the config when we build tasks from the
        command line parsers.
        """

        class MyClass2(luigi.Task):
            """TODO: Make luigi clean its registry between tests, so this "2" suffix dance isn't needed."""

            p_not_global_no_default = luigi.Parameter()

            def complete(self):
                import sys

                # NOTE(review): the config is dumped twice; this looks like an
                # accidental duplicate of the debug line and one can be dropped.
                luigi.configuration.get_config().write(sys.stdout)
                luigi.configuration.get_config().write(sys.stdout)
                if self.p_not_global_no_default != "123":
                    raise ValueError("The parameter didn't get set!!")
                return True

            def run(self):
                pass

        self.assertTrue(self.run_locally(["MyClass2"]))
        self.assertFalse(self.run_locally(["MyClass2", "--p-not-global-no-default", "124"]))
        self.assertFalse(self.run_locally(["MyClass2", "--MyClass2-p-not-global-no-default", "124"]))

    @with_config({"mynamespace.A": {"p": "999"}})
    def testWithNamespaceConfig(self):
        # Namespaced tasks read their section as "<namespace>.<Class>".
        class A(luigi.Task):
            task_namespace = "mynamespace"
            p = luigi.IntParameter()

        self.assertEqual(999, A().p)
        self.assertEqual(777, A(p=777).p)

    def testWithNamespaceCli(self):
        class A(luigi.Task):
            task_namespace = "mynamespace"
            p = luigi.IntParameter(default=100)
            expected = luigi.IntParameter()

            def complete(self):
                if self.p != self.expected:
                    raise ValueError
                return True

        self.assertTrue(self.run_locally_split("mynamespace.A --expected 100"))
        # TODO(arash): Why is `--p 200` hanging with multiprocessing stuff?
        # self.assertTrue(self.run_locally_split('mynamespace.A --p 200 --expected 200'))
        self.assertTrue(self.run_locally_split("mynamespace.A --mynamespace.A-p 200 --expected 200"))
        # --A-p is unrecognized since module-level A is _visible_in_registry=False (no CLI flag)
        self.assertRaises(SystemExit, self.run_locally_split, "mynamespace.A --A-p 200 --expected 200")

    def testListWithNamespaceCli(self):
        class A(luigi.Task):
            task_namespace = "mynamespace"
            l_param = luigi.ListParameter(default=[1, 2, 3])
            expected = luigi.ListParameter()

            def complete(self):
                if self.l_param != self.expected:
                    raise ValueError
                return True

        self.assertTrue(self.run_locally_split("mynamespace.A --expected [1,2,3]"))
        # NOTE(review): the parameter is l_param, yet the flag is spelled
        # "--mynamespace.A-l". It presumably only parses because argparse
        # accepts unambiguous prefixes of "--mynamespace.A-l-param".
        self.assertTrue(self.run_locally_split("mynamespace.A --mynamespace.A-l [1,2,3] --expected [1,2,3]"))

    def testTupleWithNamespaceCli(self):
        class A(luigi.Task):
            task_namespace = "mynamespace"
            t = luigi.TupleParameter(default=((1, 2), (3, 4)))
            expected = luigi.TupleParameter()

            def complete(self):
                if self.t != self.expected:
                    raise ValueError
                return True

        self.assertTrue(self.run_locally_split("mynamespace.A --expected ((1,2),(3,4))"))
        self.assertTrue(self.run_locally_split("mynamespace.A --mynamespace.A-t ((1,2),(3,4)) --expected ((1,2),(3,4))"))

    @with_config({"foo": {"bar": "[1,2,3]"}})
    def testListConfig(self):
        self.assertTrue(_value(luigi.ListParameter(config_path=dict(section="foo", name="bar"))))

    @with_config({"foo": {"bar": "((1,2),(3,4))"}})
    def testTupleConfig(self):
        self.assertTrue(_value(luigi.TupleParameter(config_path=dict(section="foo", name="bar"))))

    @with_config({"foo": {"bar": "-3"}})
    def testNumericalParameter(self):
        # -3 equals min_value, so this also checks the lower bound is inclusive.
        p = luigi.NumericalParameter(min_value=-3, max_value=7, var_type=int, config_path=dict(section="foo", name="bar"))
        self.assertEqual(-3, _value(p))

    @with_config({"foo": {"bar": "3"}})
    def testChoiceParameter(self):
        p = luigi.ChoiceParameter(var_type=int, choices=[1, 2, 3], config_path=dict(section="foo", name="bar"))
        self.assertEqual(3, _value(p))


# Deprecated spellings of the [core] scheduler port option still work but warn.
class OverrideEnvStuff(LuigiTestCase):
    @with_config({"core": {"default-scheduler-port": "6543"}})
    def testOverrideSchedulerPort(self):
        with self.assertWarnsRegex(DeprecationWarning, r"default-scheduler-port is deprecated"):
            env_params = luigi.interface.core()
            self.assertEqual(env_params.scheduler_port, 6543)

    @with_config({"core": {"scheduler-port": "6544"}})
    def testOverrideSchedulerPort2(self):
        with self.assertWarnsRegex(DeprecationWarning, r"scheduler-port \(with dashes\) should be avoided"):
            env_params = luigi.interface.core()
            self.assertEqual(env_params.scheduler_port, 6544)

    @with_config({"core": {"scheduler_port": "6545"}})
    def testOverrideSchedulerPort3(self):
        env_params = luigi.interface.core()
        self.assertEqual(env_params.scheduler_port, 6545)


class TestSerializeDateParameters(LuigiTestCase):
    def testSerialize(self):
        date = datetime.date(2013, 2, 3)
        self.assertEqual(luigi.DateParameter().serialize(date), "2013-02-03")
        self.assertEqual(luigi.YearParameter().serialize(date), "2013")
        self.assertEqual(luigi.MonthParameter().serialize(date), "2013-02")
        dt = datetime.datetime(2013, 2, 3, 4, 5)
        # Minutes are dropped at hour granularity.
        self.assertEqual(luigi.DateHourParameter().serialize(dt), "2013-02-03T04")


class TestSerializeTimeDeltaParameters(LuigiTestCase):
    def testSerialize(self):
        tdelta = timedelta(weeks=5, days=4, hours=3, minutes=2, seconds=1)
        self.assertEqual(luigi.TimeDeltaParameter().serialize(tdelta), "5 w 4 d 3 h 2 m 1 s")
        tdelta = timedelta(seconds=0)
        self.assertEqual(luigi.TimeDeltaParameter().serialize(tdelta), "0 w 0 d 0 h 0 m 0 s")


class TestTaskParameter(LuigiTestCase):
    def testUsage(self):
        class MetaTask(luigi.Task):
            task_namespace = "mynamespace"
            a = luigi.TaskParameter()

            def run(self):
                # Stash the received class on the class so the test can inspect it.
                self.__class__.saved_value = self.a

        class OtherTask(luigi.Task):
            task_namespace = "other_namespace"

        self.assertEqual(MetaTask(a=MetaTask).a, MetaTask)
        self.assertEqual(MetaTask(a=OtherTask).a, OtherTask)

        # So I first thought this "should" work, but actually it should not,
        # because it should not need to parse values known at run-time
        self.assertRaises(AttributeError, lambda: MetaTask(a="mynamespace.MetaTask"))

        # But it should be able to parse command line arguments
        self.assertRaises(luigi.task_register.TaskClassNotFoundException, lambda: self.run_locally_split("mynamespace.MetaTask --a blah"))
        self.assertRaises(luigi.task_register.TaskClassNotFoundException, lambda: self.run_locally_split("mynamespace.MetaTask --a Taskk"))
        self.assertTrue(self.run_locally_split("mynamespace.MetaTask --a mynamespace.MetaTask"))
        self.assertEqual(MetaTask.saved_value, MetaTask)
        self.assertTrue(self.run_locally_split("mynamespace.MetaTask --a other_namespace.OtherTask"))
        self.assertEqual(MetaTask.saved_value, OtherTask)

    def testSerialize(self):
        class OtherTask(luigi.Task):
            def complete(self):
                return True

        class DepTask(luigi.Task):
            dep = luigi.TaskParameter()
            ran = False

            def complete(self):
                return self.__class__.ran

            def requires(self):
                return self.dep()

            def run(self):
                self.__class__.ran = True

        class MainTask(luigi.Task):
            def run(self):
                # Dynamic dependency: DepTask is yielded from run().
                yield DepTask(dep=OtherTask)

        # OtherTask is serialized because it is used as an argument for DepTask.
        self.assertTrue(self.run_locally(["MainTask"]))


class TestSerializeTupleParameter(LuigiTestCase):
    def testSerialize(self):
        # serialize() followed by parse() must round-trip.
        the_tuple = (1, 2, 3)
        self.assertEqual(luigi.TupleParameter().parse(luigi.TupleParameter().serialize(the_tuple)), the_tuple)


class NewStyleParameters822Test(LuigiTestCase):
    """
    I bet these tests created at 2015-03-08 are redundant by now (Oct 2015).
    But maintaining them anyway, just in case I have overlooked something.
    """

    # See https://github.com/spotify/luigi/issues/822

    def test_subclasses(self):
        class BarBaseClass(luigi.Task):
            x = luigi.Parameter(default="bar_base_default")

        class BarSubClass(BarBaseClass):
            pass

        in_parse(["BarSubClass", "--x", "xyz", "--BarBaseClass-x", "xyz"], lambda task: self.assertEqual(task.x, "xyz"))

        # https://github.com/spotify/luigi/issues/822#issuecomment-77782714
        in_parse(["BarBaseClass", "--BarBaseClass-x", "xyz"], lambda task: self.assertEqual(task.x, "xyz"))


class LocalParameters1304Test(LuigiTestCase):
    """
    It was discussed and decided that local parameters (--x) should be
    semantically different from global parameters (--MyTask-x).

    The former sets only the parsed root task, and the latter sets the
    parameter for all the tasks.

    https://github.com/spotify/luigi/issues/1304#issuecomment-148402284
    """

    def test_local_params(self):
        class MyTask(RunOnceTask):
            param1 = luigi.IntParameter()
            param2 = luigi.BoolParameter(default=False)

            def requires(self):
                if self.param1 > 0:
                    yield MyTask(param1=(self.param1 - 1))

            def run(self):
                # --param2 applies only to the root (param1 == 1); the child
                # MyTask(param1=0) must see the default False.
                assert self.param1 == 1 or not self.param2
                # presumably the flag RunOnceTask's complete() reads — verify
                self.comp = True

        self.assertTrue(self.run_locally_split("MyTask --param1 1 --param2"))

    def test_local_takes_precedence(self):
        class MyTask(luigi.Task):
            param = luigi.IntParameter()

            def complete(self):
                return False

            def run(self):
                # For the root task, local --param 5 beats global --MyTask-param 6.
                assert self.param == 5

        self.assertTrue(self.run_locally_split("MyTask --param 5 --MyTask-param 6"))

    def test_local_only_affects_root(self):
        class MyTask(RunOnceTask):
            param = luigi.IntParameter(default=3)

            def requires(self):
                # Neither the root (5) nor the child (6, from the global flag)
                # should ever fall back to the declared default.
                assert self.param != 3
                if self.param == 5:
                    yield MyTask()

        # It would be a cyclic dependency if local took precedence
        self.assertTrue(self.run_locally_split("MyTask --param 5 --MyTask-param 6"))

    def test_range_doesnt_propagate_args(self):
        """
        Ensure that ``--task Range --of Blah --blah-arg 123`` doesn't work.

        This will of course not work unless support is explicitly added for it.
        But being a bit paranoid here and adding this test case so that if
        somebody decides to add it in the future, they'll be redirected to the
        discussion in #1304
        """

        class Blah(RunOnceTask):
            date = luigi.DateParameter()
            blah_arg = luigi.IntParameter()

        # The SystemExit is assumed to be thrown by argparse
        self.assertRaises(SystemExit, self.run_locally_split, "RangeDailyBase --of Blah --start 2015-01-01 --task-limit 1 --blah-arg 123")
        self.assertTrue(self.run_locally_split("RangeDailyBase --of Blah --start 2015-01-01 --task-limit 1 --Blah-blah-arg 123"))


class TaskAsParameterName1335Test(LuigiTestCase):
    def test_parameter_can_be_named_task(self):
        class MyTask(luigi.Task):
            # Indeed, this is not the most realistic example, but still ...
            task = luigi.IntParameter()

        self.assertTrue(self.run_locally_split("MyTask --task 5"))


# pytest-style (not LuigiTestCase) tests for PathParameter /
# OptionalPathParameter. They run across every combination of the
# default / absolute / exists fixtures.
class TestPathParameter:
    @pytest.fixture(params=[None, "not_existing_dir"])
    def default(self, request):
        return request.param

    @pytest.fixture(params=[True, False])
    def absolute(self, request):
        return request.param

    @pytest.fixture(params=[True, False])
    def exists(self, request):
        return request.param

    @pytest.fixture()
    def path_parameter(self, tmpdir, default, absolute, exists):
        class TaskPathParameter(luigi.Task):
            a = luigi.PathParameter(
                default=str(tmpdir / default) if default is not None else str(tmpdir),
                absolute=absolute,
                exists=exists,
            )
            b = luigi.OptionalPathParameter(
                default=str(tmpdir / default) if default is not None else str(tmpdir),
                absolute=absolute,
                exists=exists,
            )
            c = luigi.OptionalPathParameter(default=None)
            d = luigi.OptionalPathParameter(default="not empty default")

            def run(self):
                # Use the parameter as a Path object
                new_file = self.a / "test.file"
                new_optional_file = self.b / "test_optional.file"
                if default is not None:
                    # "not_existing_dir" has to be created before writing into it.
                    new_file.parent.mkdir(parents=True)
                new_file.touch()
                new_optional_file.touch()
                assert new_file.exists()
                assert new_optional_file.exists()
                assert self.c is None
                # test_exists configures d as "", which must surface as None.
                assert self.d is None

            def output(self):
                # Never exists, so luigi.build always runs the task.
                return luigi.LocalTarget("not_existing_file")

        return {
            "tmpdir": tmpdir,
            "default": default,
            "absolute": absolute,
            "exists": exists,
            "cls": TaskPathParameter,
        }

    @with_config({"TaskPathParameter": {"d": ""}})
    def test_exists(self, path_parameter):
        # Only a missing default path combined with exists=True must be rejected.
        if path_parameter["default"] is not None and path_parameter["exists"]:
            with pytest.raises(ValueError, match="The path .* does not exist"):
                luigi.build([path_parameter["cls"]()], local_scheduler=True)
        else:
            assert luigi.build([path_parameter["cls"]()], local_scheduler=True)
================================================
FILE: test/priority_test.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from helpers import unittest import luigi import luigi.notifications luigi.notifications.DEBUG = True class PrioTask(luigi.Task): prio = luigi.Parameter() run_counter = 0 @property def priority(self): return self.prio def requires(self): if self.prio > 10: return PrioTask(self.prio - 10) def run(self): self.t = PrioTask.run_counter PrioTask.run_counter += 1 def complete(self): return hasattr(self, "t") class PriorityTest(unittest.TestCase): def test_priority(self): p, q, r = PrioTask(1), PrioTask(2), PrioTask(3) luigi.build([p, q, r], local_scheduler=True) self.assertTrue(r.t < q.t < p.t) def test_priority_w_dep(self): x, y, z = PrioTask(25), PrioTask(15), PrioTask(5) a, b, c = PrioTask(24), PrioTask(14), PrioTask(4) luigi.build([a, b, c, x, y, z], local_scheduler=True) self.assertTrue(z.t < y.t < x.t < c.t < b.t < a.t) ================================================ FILE: test/range_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#
import datetime
import fnmatch

import mock
from helpers import LuigiTestCase, unittest

import luigi
from luigi.mock import MockFileSystem, MockTarget
from luigi.tools.range import (
    RangeByMinutes,
    RangeByMinutesBase,
    RangeDaily,
    RangeDailyBase,
    RangeEvent,
    RangeHourly,
    RangeHourlyBase,
    RangeMonthly,
    _constrain_glob,
    _get_filesystems_and_globs,
)


# The Common*Task fixtures below write MockTargets whose paths embed the
# date/time in deliberately odd strftime layouts.
class CommonDateMinuteTask(luigi.Task):
    dh = luigi.DateMinuteParameter()

    def output(self):
        return MockTarget(self.dh.strftime("/n2000y01a05n/%Y_%m-_-%daww/21mm%H%Mdara21/ooo"))


class CommonDateHourTask(luigi.Task):
    dh = luigi.DateHourParameter()

    def output(self):
        return MockTarget(self.dh.strftime("/n2000y01a05n/%Y_%m-_-%daww/21mm%Hdara21/ooo"))


class CommonDateTask(luigi.Task):
    d = luigi.DateParameter()

    def output(self):
        return MockTarget(self.d.strftime("/n2000y01a05n/%Y_%m-_-%daww/21mm01dara21/ooo"))


class CommonMonthTask(luigi.Task):
    m = luigi.MonthParameter()

    def output(self):
        return MockTarget(self.m.strftime("/n2000y01a05n/%Y_%maww/21mm01dara21/ooo"))


# Simulated listing for TaskA outputs, including attempt/temp artifacts next to
# the real hourly outputs.
task_a_paths = [
    "TaskA/2014-03-20/18",
    "TaskA/2014-03-20/21",
    "TaskA/2014-03-20/23",
    "TaskA/2014-03-21/00",
    "TaskA/2014-03-21/00.attempt.1",
    "TaskA/2014-03-21/00.attempt.2",
    "TaskA/2014-03-21/01",
    "TaskA/2014-03-21/02",
    "TaskA/2014-03-21/03.attempt-temp-2014-03-21T13-22-58.165969",
    "TaskA/2014-03-21/03.attempt.1",
    "TaskA/2014-03-21/03.attempt.2",
    "TaskA/2014-03-21/03.attempt.3",
    "TaskA/2014-03-21/03.attempt.latest",
    "TaskA/2014-03-21/04.attempt-temp-2014-03-21T13-23-09.078249",
    "TaskA/2014-03-21/12",
    "TaskA/2014-03-23/12",
]

# Simulated listing for TaskB outputs (complicator "no/worries" baked into the path).
task_b_paths = [
    "TaskB/no/worries2014-03-20/23",
    "TaskB/no/worries2014-03-21/01",
    "TaskB/no/worries2014-03-21/03",
    "TaskB/no/worries2014-03-21/04.attempt-yadayada",
    "TaskB/no/worries2014-03-21/05",
]

mock_contents = task_a_paths + task_b_paths


expected_a = [
    "TaskA(dh=2014-03-20T17)",
    "TaskA(dh=2014-03-20T19)",
    "TaskA(dh=2014-03-20T20)",
]

# expected_reverse = [
# ]

expected_wrapper = [
    "CommonWrapperTask(dh=2014-03-21T00)",
    "CommonWrapperTask(dh=2014-03-21T02)",
    "CommonWrapperTask(dh=2014-03-21T03)",
    "CommonWrapperTask(dh=2014-03-21T04)",
    "CommonWrapperTask(dh=2014-03-21T05)",
]


class TaskA(luigi.Task):
    dh = luigi.DateHourParameter()

    def output(self):
        return MockTarget(self.dh.strftime("TaskA/%Y-%m-%d/%H"))


class TaskB(luigi.Task):
    dh = luigi.DateHourParameter()
    complicator = luigi.Parameter()

    def output(self):
        # strftime turns "%%s" into "%s", which the % operator then fills with
        # the complicator string.
        return MockTarget(self.dh.strftime("TaskB/%%s%Y-%m-%d/%H") % self.complicator)


class TaskC(luigi.Task):
    dh = luigi.DateHourParameter()

    def output(self):
        return MockTarget(self.dh.strftime("not/a/real/path/%Y-%m-%d/%H"))


class CommonWrapperTask(luigi.WrapperTask):
    dh = luigi.DateHourParameter()

    def requires(self):
        yield TaskA(dh=self.dh)
        yield TaskB(dh=self.dh, complicator="no/worries")  # str(self.dh) would complicate beyond working


class TaskMinutesA(luigi.Task):
    dm = luigi.DateMinuteParameter()

    def output(self):
        return MockTarget(self.dm.strftime("TaskA/%Y-%m-%d/%H%M"))


class TaskMinutesB(luigi.Task):
    dm = luigi.DateMinuteParameter()
    complicator = luigi.Parameter()

    def output(self):
        # Same "%%s" trick as TaskB.output.
        return MockTarget(self.dm.strftime("TaskB/%%s%Y-%m-%d/%H%M") % self.complicator)


class TaskMinutesC(luigi.Task):
    dm = luigi.DateMinuteParameter()

    def output(self):
        return MockTarget(self.dm.strftime("not/a/real/path/%Y-%m-%d/%H%M"))


class CommonWrapperTaskMinutes(luigi.WrapperTask):
    dm = luigi.DateMinuteParameter()

    def requires(self):
        yield TaskMinutesA(dm=self.dm)
        yield TaskMinutesB(dm=self.dm, complicator="no/worries")  # str(self.dm) would complicate beyond working


def mock_listdir(contents):
    """Return a fake ``listdir(self, glob)`` yielding entries of ``contents`` that match ``glob + "*"``."""

    def contents_listdir(_, glob):
        for path in fnmatch.filter(contents, glob + "*"):
            yield path

    return contents_listdir


# NOTE(review): both helpers below are generator functions, so calling either
# returns a generator object, which is truthy. Even mock_exists_always_false(...)
# therefore evaluates as True in a boolean context. Check that callers rely on
# this before changing it.
def mock_exists_always_true(_, _2):
    yield True


def mock_exists_always_false(_, _2):
    yield False


class ConstrainGlobTest(unittest.TestCase):
    def test_limit(self):
        glob = "/[0-9][0-9][0-9][0-9]/[0-9][0-9]/[0-9][0-9]/[0-9][0-9]"
        # 40 consecutive hourly paths starting at 2013-12-31T05.
        paths = [(datetime.datetime(2013, 12, 31, 5) + datetime.timedelta(hours=h)).strftime("/%Y/%m/%d/%H") for h in range(40)]
        self.assertEqual(
            sorted(_constrain_glob(glob, paths)),
            [
                "/2013/12/31/[0-2][0-9]",
                "/2014/01/01/[0-2][0-9]",
            ],
        )
        # Drop one hour (2014-01-01T07) so the generated glob must exclude it.
        paths.pop(26)
        self.assertEqual(
            sorted(_constrain_glob(glob, paths, 6)),
            [
                "/2013/12/31/0[5-9]",
                "/2013/12/31/1[0-9]",
                "/2013/12/31/2[0-3]",
                "/2014/01/01/0[012345689]",
                "/2014/01/01/1[0-9]",
                "/2014/01/01/2[0]",
            ],
        )
        self.assertEqual(
            sorted(_constrain_glob(glob, paths[:7], 10)),
            [
                "/2013/12/31/05",
                "/2013/12/31/06",
                "/2013/12/31/07",
                "/2013/12/31/08",
                "/2013/12/31/09",
                "/2013/12/31/10",
                "/2013/12/31/11",
            ],
        )

    def test_no_wildcards(self):
        glob = "/2014/01"
        # NOTE(review): a str rather than a list of paths (compare test_limit).
        # Presumably harmless because a wildcard-free glob is returned as-is.
        paths = "/2014/01"
        self.assertEqual(
            _constrain_glob(glob, paths),
            [
                "/2014/01",
            ],
        )


def datetime_to_epoch(dt):
    """Seconds since 1970-01-01 for a naive datetime (no timezone conversion)."""
    td = dt - datetime.datetime(1970, 1, 1)
    return td.days * 86400 + td.seconds + td.microseconds / 1e6


class RangeDailyBaseTest(unittest.TestCase):
    maxDiff = None

    def setUp(self):
        # yucky to create separate callbacks; would be nicer if the callback
        # received an instance of a subclass of Event, so one callback could
        # accumulate all types
        @RangeDailyBase.event_handler(RangeEvent.DELAY)
        def callback_delay(*args):
            self.events.setdefault(RangeEvent.DELAY, []).append(args)

        @RangeDailyBase.event_handler(RangeEvent.COMPLETE_COUNT)
        def callback_complete_count(*args):
            self.events.setdefault(RangeEvent.COMPLETE_COUNT, []).append(args)

        @RangeDailyBase.event_handler(RangeEvent.COMPLETE_FRACTION)
        def callback_complete_fraction(*args):
            self.events.setdefault(RangeEvent.COMPLETE_FRACTION, []).append(args)

        # Filled by the callbacks above. Keys are the RangeEvent values, which
        # the expected dicts below spell out as "event.tools.range.*" strings.
        self.events = {}

    def test_consistent_formatting(self):
        task = RangeDailyBase(of=CommonDateTask, start=datetime.date(2016, 1, 1))
        self.assertEqual(task._format_range([datetime.datetime(2016, 1, 2, 13), datetime.datetime(2016, 2, 29, 23)]), "[2016-01-02, 2016-02-29]")

    def _empty_subcase(self, kwargs, expected_events):
        """Assert that a RangeDailyBase built from ``kwargs`` has nothing to do."""
        calls = []

        class RangeDailyDerived(RangeDailyBase):
            def missing_datetimes(self, task_cls, finite_datetimes):
                args = [self, task_cls, finite_datetimes]
                calls.append(args)
                return args[-1][:5]

        task = RangeDailyDerived(of=CommonDateTask, **kwargs)
        self.assertEqual(task.requires(), [])
        self.assertEqual(calls, [])
        self.assertEqual(task.requires(), [])
        self.assertEqual(calls, [])  # subsequent requires() should return the cached result, never call missing_datetimes
        self.assertEqual(self.events, expected_events)
        self.assertTrue(task.complete())

    def test_stop_before_days_back(self):
        # nothing to do because stop is earlier
        self._empty_subcase(
            {
                "now": datetime_to_epoch(datetime.datetime(2015, 1, 1, 4)),
                "stop": datetime.date(2014, 3, 20),
                "days_back": 4,
                "days_forward": 20,
                "reverse": True,
            },
            {
                "event.tools.range.delay": [
                    ("CommonDateTask", 0),
                ],
                "event.tools.range.complete.count": [
                    ("CommonDateTask", 0),
                ],
                "event.tools.range.complete.fraction": [
                    ("CommonDateTask", 1.0),
                ],
            },
        )

    def _nonempty_subcase(self, kwargs, expected_finite_datetimes_range, expected_requires, expected_events):
        """Assert the requires(), scanned range and events of a RangeDailyBase built from ``kwargs``."""
        calls = []

        class RangeDailyDerived(RangeDailyBase):
            def missing_datetimes(self, finite_datetimes):
                # I only changed tests for number of arguments at this one
                # place to test both old and new behavior
                calls.append((self, finite_datetimes))
                return finite_datetimes[:7]

        task = RangeDailyDerived(of=CommonDateTask, **kwargs)
        self.assertEqual(list(map(str, task.requires())), expected_requires)
        self.assertEqual((min(calls[0][1]), max(calls[0][1])), expected_finite_datetimes_range)
        self.assertEqual(list(map(str, task.requires())), expected_requires)
        self.assertEqual(len(calls), 1)  # subsequent requires() should return the cached result, not call missing_datetimes again
        self.assertEqual(self.events, expected_events)
        self.assertFalse(task.complete())

    def test_start_long_before_long_days_back_and_with_long_days_forward(self):
        self._nonempty_subcase(
            {
                "now": datetime_to_epoch(datetime.datetime(2017, 10, 22, 12, 4, 29)),
                "start": datetime.date(2011, 3, 20),
                "stop": datetime.date(2025, 1, 29),
                "task_limit": 4,
                "days_back": 3 * 365,
                "days_forward": 3 * 365,
            },
            (datetime.datetime(2014, 10, 24), datetime.datetime(2020, 10, 21)),
            [
                "CommonDateTask(d=2014-10-24)",
                "CommonDateTask(d=2014-10-25)",
                "CommonDateTask(d=2014-10-26)",
                "CommonDateTask(d=2014-10-27)",
            ],
            {
                "event.tools.range.delay": [
                    ("CommonDateTask", 3750),
                ],
                "event.tools.range.complete.count": [
                    ("CommonDateTask", 5057),
                ],
                "event.tools.range.complete.fraction": [
                    ("CommonDateTask", 5057.0 / (5057 + 7)),
                ],
            },
        )


# NOTE(review): RangeHourlyBaseTest continues past this chunk, so it is left
# below exactly in its original collapsed form.
class RangeHourlyBaseTest(unittest.TestCase): maxDiff = None def setUp(self): # yucky to create separate callbacks; would be nicer if the callback # received an instance of a subclass of Event, so one callback could # accumulate all types @RangeHourlyBase.event_handler(RangeEvent.DELAY) def callback_delay(*args): self.events.setdefault(RangeEvent.DELAY, []).append(args) @RangeHourlyBase.event_handler(RangeEvent.COMPLETE_COUNT) def callback_complete_count(*args): self.events.setdefault(RangeEvent.COMPLETE_COUNT, []).append(args) @RangeHourlyBase.event_handler(RangeEvent.COMPLETE_FRACTION) def callback_complete_fraction(*args): self.events.setdefault(RangeEvent.COMPLETE_FRACTION, []).append(args) self.events = {} def test_consistent_formatting(self): task = RangeHourlyBase(of=CommonDateHourTask, start=datetime.datetime(2016, 1, 1)) self.assertEqual(task._format_range([datetime.datetime(2016, 1, 2, 13), datetime.datetime(2016, 2, 29, 23)]), "[2016-01-02T13, 2016-02-29T23]") def _empty_subcase(self, kwargs, expected_events): calls = [] class RangeHourlyDerived(RangeHourlyBase): def missing_datetimes(a, b, c): args = [a, b, c] calls.append(args) return args[-1][:5] task = RangeHourlyDerived(of=CommonDateHourTask, **kwargs) self.assertEqual(task.requires(), []) self.assertEqual(calls, []) self.assertEqual(task.requires(), []) self.assertEqual(calls, []) # subsequent requires() should return the cached result, never call missing_datetimes self.assertEqual(self.events, expected_events)
self.assertTrue(task.complete()) def test_start_after_hours_forward(self): # nothing to do because start is later self._empty_subcase( { "now": datetime_to_epoch(datetime.datetime(2000, 1, 1, 4)), "start": datetime.datetime(2014, 3, 20, 17), "hours_back": 4, "hours_forward": 20, }, { "event.tools.range.delay": [ ("CommonDateHourTask", 0), ], "event.tools.range.complete.count": [ ("CommonDateHourTask", 0), ], "event.tools.range.complete.fraction": [ ("CommonDateHourTask", 1.0), ], }, ) def _nonempty_subcase(self, kwargs, expected_finite_datetimes_range, expected_requires, expected_events): calls = [] class RangeHourlyDerived(RangeHourlyBase): def missing_datetimes(a, b, c): args = [a, b, c] calls.append(args) return args[-1][:7] task = RangeHourlyDerived(of=CommonDateHourTask, **kwargs) self.assertEqual(list(map(str, task.requires())), expected_requires) self.assertEqual(calls[0][1], CommonDateHourTask) self.assertEqual((min(calls[0][2]), max(calls[0][2])), expected_finite_datetimes_range) self.assertEqual(list(map(str, task.requires())), expected_requires) self.assertEqual(len(calls), 1) # subsequent requires() should return the cached result, not call missing_datetimes again self.assertEqual(self.events, expected_events) self.assertFalse(task.complete()) def test_start_long_before_hours_back(self): self._nonempty_subcase( { "now": datetime_to_epoch(datetime.datetime(2000, 1, 1, 4)), "start": datetime.datetime(1960, 3, 2, 1), "hours_back": 5, "hours_forward": 20, }, (datetime.datetime(1999, 12, 31, 23), datetime.datetime(2000, 1, 1, 23)), [ "CommonDateHourTask(dh=1999-12-31T23)", "CommonDateHourTask(dh=2000-01-01T00)", "CommonDateHourTask(dh=2000-01-01T01)", "CommonDateHourTask(dh=2000-01-01T02)", "CommonDateHourTask(dh=2000-01-01T03)", "CommonDateHourTask(dh=2000-01-01T04)", "CommonDateHourTask(dh=2000-01-01T05)", ], { "event.tools.range.delay": [ ("CommonDateHourTask", 25), # because of short hours_back we're oblivious to those 40 preceding years ], 
"event.tools.range.complete.count": [ ("CommonDateHourTask", 349192), ], "event.tools.range.complete.fraction": [ ("CommonDateHourTask", 349192.0 / (349192 + 7)), ], }, ) def test_start_after_long_hours_back(self): self._nonempty_subcase( { "now": datetime_to_epoch(datetime.datetime(2014, 10, 22, 12, 4, 29)), "start": datetime.datetime(2014, 3, 20, 17), "task_limit": 4, "hours_back": 365 * 24, }, (datetime.datetime(2014, 3, 20, 17), datetime.datetime(2014, 10, 22, 12)), [ "CommonDateHourTask(dh=2014-03-20T17)", "CommonDateHourTask(dh=2014-03-20T18)", "CommonDateHourTask(dh=2014-03-20T19)", "CommonDateHourTask(dh=2014-03-20T20)", ], { "event.tools.range.delay": [ ("CommonDateHourTask", 5180), ], "event.tools.range.complete.count": [ ("CommonDateHourTask", 5173), ], "event.tools.range.complete.fraction": [ ("CommonDateHourTask", 5173.0 / (5173 + 7)), ], }, ) def test_start_long_before_long_hours_back_and_with_long_hours_forward(self): self._nonempty_subcase( { "now": datetime_to_epoch(datetime.datetime(2017, 10, 22, 12, 4, 29)), "start": datetime.datetime(2011, 3, 20, 17), "task_limit": 4, "hours_back": 3 * 365 * 24, "hours_forward": 3 * 365 * 24, }, (datetime.datetime(2014, 10, 23, 13), datetime.datetime(2020, 10, 21, 12)), [ "CommonDateHourTask(dh=2014-10-23T13)", "CommonDateHourTask(dh=2014-10-23T14)", "CommonDateHourTask(dh=2014-10-23T15)", "CommonDateHourTask(dh=2014-10-23T16)", ], { "event.tools.range.delay": [ ("CommonDateHourTask", 52560), ], "event.tools.range.complete.count": [ ("CommonDateHourTask", 84061), ], "event.tools.range.complete.fraction": [ ("CommonDateHourTask", 84061.0 / (84061 + 7)), ], }, ) class RangeByMinutesBaseTest(unittest.TestCase): maxDiff = None def setUp(self): # yucky to create separate callbacks; would be nicer if the callback # received an instance of a subclass of Event, so one callback could # accumulate all types @RangeByMinutesBase.event_handler(RangeEvent.DELAY) def callback_delay(*args): 
self.events.setdefault(RangeEvent.DELAY, []).append(args) @RangeByMinutesBase.event_handler(RangeEvent.COMPLETE_COUNT) def callback_complete_count(*args): self.events.setdefault(RangeEvent.COMPLETE_COUNT, []).append(args) @RangeByMinutesBase.event_handler(RangeEvent.COMPLETE_FRACTION) def callback_complete_fraction(*args): self.events.setdefault(RangeEvent.COMPLETE_FRACTION, []).append(args) self.events = {} def test_consistent_formatting(self): task = RangeByMinutesBase(of=CommonDateMinuteTask, start=datetime.datetime(2016, 1, 1, 13), minutes_interval=5) self.assertEqual( task._format_range([datetime.datetime(2016, 1, 2, 13, 10), datetime.datetime(2016, 2, 29, 23, 20)]), "[2016-01-02T1310, 2016-02-29T2320]" ) def _empty_subcase(self, kwargs, expected_events): calls = [] class RangeByMinutesDerived(RangeByMinutesBase): def missing_datetimes(a, b, c): args = [a, b, c] calls.append(args) return args[-1][:5] task = RangeByMinutesDerived(of=CommonDateMinuteTask, **kwargs) self.assertEqual(task.requires(), []) self.assertEqual(calls, []) self.assertEqual(task.requires(), []) self.assertEqual(calls, []) # subsequent requires() should return the cached result, never call missing_datetimes self.assertEqual(self.events, expected_events) self.assertTrue(task.complete()) def test_start_after_minutes_forward(self): # nothing to do because start is later self._empty_subcase( { "now": datetime_to_epoch(datetime.datetime(2000, 1, 1, 4)), "start": datetime.datetime(2014, 3, 20, 17, 10), "minutes_back": 4, "minutes_forward": 20, "minutes_interval": 5, }, { "event.tools.range.delay": [ ("CommonDateMinuteTask", 0), ], "event.tools.range.complete.count": [ ("CommonDateMinuteTask", 0), ], "event.tools.range.complete.fraction": [ ("CommonDateMinuteTask", 1.0), ], }, ) def _nonempty_subcase(self, kwargs, expected_finite_datetimes_range, expected_requires, expected_events): calls = [] class RangeByMinutesDerived(RangeByMinutesBase): def missing_datetimes(a, b, c): args = [a, b, c] 
calls.append(args) return args[-1][:7] task = RangeByMinutesDerived(of=CommonDateMinuteTask, **kwargs) self.assertEqual(list(map(str, task.requires())), expected_requires) self.assertEqual(calls[0][1], CommonDateMinuteTask) self.assertEqual((min(calls[0][2]), max(calls[0][2])), expected_finite_datetimes_range) self.assertEqual(list(map(str, task.requires())), expected_requires) self.assertEqual(len(calls), 1) # subsequent requires() should return the cached result, not call missing_datetimes again self.assertEqual(self.events, expected_events) self.assertFalse(task.complete()) def test_negative_interval(self): class SomeByMinutesTask(luigi.Task): d = luigi.DateMinuteParameter() def output(self): return MockTarget(self.d.strftime("/data/2014/p/v/z/%Y_/_%m-_-%doctor/20/%HZ%MOOO")) task = RangeByMinutes( now=datetime_to_epoch(datetime.datetime(2016, 4, 1)), of=SomeByMinutesTask, start=datetime.datetime(2014, 3, 20, 17), minutes_interval=-1 ) self.assertRaises(luigi.parameter.ParameterException, task.requires) def test_non_dividing_interval(self): class SomeByMinutesTask(luigi.Task): d = luigi.DateMinuteParameter() def output(self): return MockTarget(self.d.strftime("/data/2014/p/v/z/%Y_/_%m-_-%doctor/20/%HZ%MOOO")) task = RangeByMinutes( now=datetime_to_epoch(datetime.datetime(2016, 4, 1)), of=SomeByMinutesTask, start=datetime.datetime(2014, 3, 20, 17), minutes_interval=8 ) self.assertRaises(luigi.parameter.ParameterException, task.requires) def test_start_and_minutes_period(self): self._nonempty_subcase( { "now": datetime_to_epoch(datetime.datetime(2016, 9, 1, 12, 0, 0)), "start": datetime.datetime(2016, 9, 1, 11, 0, 0), "minutes_back": 24 * 60, "minutes_forward": 0, "minutes_interval": 3, }, (datetime.datetime(2016, 9, 1, 11, 0), datetime.datetime(2016, 9, 1, 11, 57, 0)), [ "CommonDateMinuteTask(dh=2016-09-01T1100)", "CommonDateMinuteTask(dh=2016-09-01T1103)", "CommonDateMinuteTask(dh=2016-09-01T1106)", "CommonDateMinuteTask(dh=2016-09-01T1109)", 
"CommonDateMinuteTask(dh=2016-09-01T1112)", "CommonDateMinuteTask(dh=2016-09-01T1115)", "CommonDateMinuteTask(dh=2016-09-01T1118)", ], { "event.tools.range.delay": [ ("CommonDateMinuteTask", 20), # First missing is the 20th ], "event.tools.range.complete.count": [ ("CommonDateMinuteTask", 13), # 20 intervals - 7 missing ], "event.tools.range.complete.fraction": [ ("CommonDateMinuteTask", 13.0 / (13 + 7)), # (expected - missing) / expected ], }, ) def test_start_long_before_minutes_back(self): self._nonempty_subcase( { "now": datetime_to_epoch(datetime.datetime(2000, 1, 1, 0, 3, 0)), "start": datetime.datetime(1960, 1, 1, 0, 0, 0), "minutes_back": 5, "minutes_forward": 20, "minutes_interval": 5, }, (datetime.datetime(2000, 1, 1, 0, 0), datetime.datetime(2000, 1, 1, 0, 20, 0)), [ "CommonDateMinuteTask(dh=2000-01-01T0000)", "CommonDateMinuteTask(dh=2000-01-01T0005)", "CommonDateMinuteTask(dh=2000-01-01T0010)", "CommonDateMinuteTask(dh=2000-01-01T0015)", "CommonDateMinuteTask(dh=2000-01-01T0020)", ], { "event.tools.range.delay": [ ("CommonDateMinuteTask", 5), # because of short minutes_back we're oblivious to those 40 preceding years ], "event.tools.range.complete.count": [ ("CommonDateMinuteTask", 4207680), # expected intervals - missing. 
], "event.tools.range.complete.fraction": [ ("CommonDateMinuteTask", 4207680.0 / 4207685), # (expected - missing) / expected ], }, ) def test_start_after_long_minutes_back(self): self._nonempty_subcase( { "now": datetime_to_epoch(datetime.datetime(2014, 3, 20, 18, 4, 29)), "start": datetime.datetime(2014, 3, 20, 17, 10), "task_limit": 4, "minutes_back": 365 * 24 * 60, "minutes_interval": 5, }, (datetime.datetime(2014, 3, 20, 17, 10, 0), datetime.datetime(2014, 3, 20, 18, 0, 0)), [ "CommonDateMinuteTask(dh=2014-03-20T1710)", "CommonDateMinuteTask(dh=2014-03-20T1715)", "CommonDateMinuteTask(dh=2014-03-20T1720)", "CommonDateMinuteTask(dh=2014-03-20T1725)", ], { "event.tools.range.delay": [ ("CommonDateMinuteTask", 11), ], "event.tools.range.complete.count": [ ("CommonDateMinuteTask", 4), ], "event.tools.range.complete.fraction": [ ("CommonDateMinuteTask", 4.0 / 11), ], }, ) def test_start_long_before_long_minutes_back_and_with_long_minutes_forward(self): self._nonempty_subcase( { "now": datetime_to_epoch(datetime.datetime(2017, 3, 22, 20, 4, 29)), "start": datetime.datetime(2011, 3, 20, 17, 10, 0), "task_limit": 4, "minutes_back": 365 * 24 * 60, "minutes_forward": 365 * 24 * 60, "minutes_interval": 5, }, (datetime.datetime(2016, 3, 22, 20, 5), datetime.datetime(2018, 3, 22, 20, 0)), [ "CommonDateMinuteTask(dh=2016-03-22T2005)", "CommonDateMinuteTask(dh=2016-03-22T2010)", "CommonDateMinuteTask(dh=2016-03-22T2015)", "CommonDateMinuteTask(dh=2016-03-22T2020)", ], { "event.tools.range.delay": [ ("CommonDateMinuteTask", 210240), ], "event.tools.range.complete.count": [ ("CommonDateMinuteTask", 737020), ], "event.tools.range.complete.fraction": [ ("CommonDateMinuteTask", 737020.0 / (737020 + 7)), ], }, ) class FilesystemInferenceTest(unittest.TestCase): def _test_filesystems_and_globs(self, datetime_to_task, datetime_to_re, expected): actual = list(_get_filesystems_and_globs(datetime_to_task, datetime_to_re)) self.assertEqual(len(actual), len(expected)) for 
(actual_filesystem, actual_glob), (expected_filesystem, expected_glob) in zip(actual, expected): self.assertTrue(isinstance(actual_filesystem, expected_filesystem)) self.assertEqual(actual_glob, expected_glob) def test_date_glob_successfully_inferred(self): self._test_filesystems_and_globs( lambda d: CommonDateTask(d), lambda d: d.strftime("(%Y).*(%m).*(%d)"), [ (MockFileSystem, "/n2000y01a05n/[0-9][0-9][0-9][0-9]_[0-9][0-9]-_-[0-9][0-9]aww/21mm01dara21"), ], ) def test_datehour_glob_successfully_inferred(self): self._test_filesystems_and_globs( lambda d: CommonDateHourTask(d), lambda d: d.strftime("(%Y).*(%m).*(%d).*(%H)"), [ (MockFileSystem, "/n2000y01a05n/[0-9][0-9][0-9][0-9]_[0-9][0-9]-_-[0-9][0-9]aww/21mm[0-9][0-9]dara21"), ], ) def test_dateminute_glob_successfully_inferred(self): self._test_filesystems_and_globs( lambda d: CommonDateMinuteTask(d), lambda d: d.strftime("(%Y).*(%m).*(%d).*(%H).*(%M)"), [ (MockFileSystem, "/n2000y01a05n/[0-9][0-9][0-9][0-9]_[0-9][0-9]-_-[0-9][0-9]aww/21mm[0-9][0-9][0-9][0-9]dara21"), ], ) def test_wrapped_datehour_globs_successfully_inferred(self): self._test_filesystems_and_globs( lambda d: CommonWrapperTask(d), lambda d: d.strftime("(%Y).*(%m).*(%d).*(%H)"), [ (MockFileSystem, "TaskA/[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]"), (MockFileSystem, "TaskB/no/worries[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]"), ], ) def test_inconsistent_output_datehour_glob_not_inferred(self): class InconsistentlyOutputtingDateHourTask(luigi.Task): dh = luigi.DateHourParameter() def output(self): base = self.dh.strftime("/even/%Y%m%d%H") if self.dh.hour % 2 == 0: return MockTarget(base) else: return { "spi": MockTarget(base + "/something.spi"), "spl": MockTarget(base + "/something.spl"), } def test_raise_not_implemented(): list(_get_filesystems_and_globs(lambda d: InconsistentlyOutputtingDateHourTask(d), lambda d: d.strftime("(%Y).*(%m).*(%d).*(%H)"))) self.assertRaises(NotImplementedError, test_raise_not_implemented) def 
test_wrapped_inconsistent_datehour_globs_not_inferred(self): class InconsistentlyParameterizedWrapperTask(luigi.WrapperTask): dh = luigi.DateHourParameter() def requires(self): yield TaskA(dh=self.dh - datetime.timedelta(days=1)) yield TaskB(dh=self.dh, complicator="no/worries") def test_raise_not_implemented(): list(_get_filesystems_and_globs(lambda d: InconsistentlyParameterizedWrapperTask(d), lambda d: d.strftime("(%Y).*(%m).*(%d).*(%H)"))) self.assertRaises(NotImplementedError, test_raise_not_implemented) class RangeMonthlyTest(unittest.TestCase): def setUp(self): # yucky to create separate callbacks; would be nicer if the callback # received an instance of a subclass of Event, so one callback could # accumulate all types @RangeMonthly.event_handler(RangeEvent.DELAY) def callback_delay(*args): self.events.setdefault(RangeEvent.DELAY, []).append(args) @RangeMonthly.event_handler(RangeEvent.COMPLETE_COUNT) def callback_complete_count(*args): self.events.setdefault(RangeEvent.COMPLETE_COUNT, []).append(args) @RangeMonthly.event_handler(RangeEvent.COMPLETE_FRACTION) def callback_complete_fraction(*args): self.events.setdefault(RangeEvent.COMPLETE_FRACTION, []).append(args) self.events = {} def _empty_subcase(self, kwargs, expected_events): calls = [] class RangeMonthlyDerived(RangeMonthly): def missing_datetimes(self, task_cls, finite_datetimes): args = [self, task_cls, finite_datetimes] calls.append(args) return args[-1][:5] task = RangeMonthlyDerived(of=CommonMonthTask, **kwargs) self.assertEqual(task.requires(), []) self.assertEqual(calls, []) self.assertEqual(task.requires(), []) self.assertEqual(calls, []) # subsequent requires() should return the cached result, never call missing_datetimes self.assertEqual(self.events, expected_events) self.assertTrue(task.complete()) def test_stop_before_months_back(self): # nothing to do because stop is earlier self._empty_subcase( { "now": datetime_to_epoch(datetime.datetime(2017, 1, 3)), "stop": datetime.date(2016, 3, 
20), "months_back": 4, "months_forward": 20, "reverse": True, }, { "event.tools.range.delay": [ ("CommonMonthTask", 0), ], "event.tools.range.complete.count": [ ("CommonMonthTask", 0), ], "event.tools.range.complete.fraction": [ ("CommonMonthTask", 1.0), ], }, ) def test_start_after_months_forward(self): # nothing to do because start is later self._empty_subcase( { "now": datetime_to_epoch(datetime.datetime(2000, 1, 1)), "start": datetime.datetime(2014, 3, 20), "months_back": 4, "months_forward": 20, }, { "event.tools.range.delay": [ ("CommonMonthTask", 0), ], "event.tools.range.complete.count": [ ("CommonMonthTask", 0), ], "event.tools.range.complete.fraction": [ ("CommonMonthTask", 1.0), ], }, ) def _nonempty_subcase(self, kwargs, expected_finite_datetimes_range, expected_requires, expected_events): calls = [] class RangeDailyDerived(RangeMonthly): def missing_datetimes(self, finite_datetimes): calls.append((self, finite_datetimes)) return finite_datetimes[:7] task = RangeDailyDerived(of=CommonMonthTask, **kwargs) self.assertEqual(list(map(str, task.requires())), expected_requires) self.assertEqual((min(calls[0][1]), max(calls[0][1])), expected_finite_datetimes_range) self.assertEqual(list(map(str, task.requires())), expected_requires) self.assertEqual(len(calls), 1) # subsequent requires() should return the cached result, not call missing_datetimes again self.assertEqual(self.events, expected_events) self.assertFalse(task.complete()) def test_start_long_before_months_back(self): total = (2000 - 1960) * 12 + 20 - 2 self._nonempty_subcase( { "now": datetime_to_epoch(datetime.datetime(2000, 1, 1)), "start": datetime.datetime(1960, 3, 2, 1), "months_back": 5, "months_forward": 20, }, (datetime.datetime(1999, 8, 1), datetime.datetime(2001, 8, 1)), [ "CommonMonthTask(m=1999-08)", "CommonMonthTask(m=1999-09)", "CommonMonthTask(m=1999-10)", "CommonMonthTask(m=1999-11)", "CommonMonthTask(m=1999-12)", "CommonMonthTask(m=2000-01)", "CommonMonthTask(m=2000-02)", ], { 
"event.tools.range.delay": [ ("CommonMonthTask", 25), ], "event.tools.range.complete.count": [ ("CommonMonthTask", total - 7), ], "event.tools.range.complete.fraction": [ ("CommonMonthTask", (total - 7.0) / total), ], }, ) def test_start_after_long_months_back(self): total = 12 - 4 self._nonempty_subcase( { "now": datetime_to_epoch(datetime.datetime(2014, 11, 22)), "start": datetime.datetime(2014, 3, 1), "task_limit": 4, "months_back": 12 * 24, }, (datetime.datetime(2014, 3, 1), datetime.datetime(2014, 10, 1)), [ "CommonMonthTask(m=2014-03)", "CommonMonthTask(m=2014-04)", "CommonMonthTask(m=2014-05)", "CommonMonthTask(m=2014-06)", ], { "event.tools.range.delay": [ ("CommonMonthTask", total), ], "event.tools.range.complete.count": [ ("CommonMonthTask", total - 7), ], "event.tools.range.complete.fraction": [ ("CommonMonthTask", (total - 7.0) / total), ], }, ) def test_start_long_before_long_months_back_and_with_long_months_forward(self): total = (2025 - 2011) * 12 - 2 self._nonempty_subcase( { "now": datetime_to_epoch(datetime.datetime(2017, 10, 22, 12, 4, 29)), "start": datetime.date(2011, 3, 20), "stop": datetime.date(2025, 1, 29), "task_limit": 4, "months_back": 3 * 12, "months_forward": 3 * 12, }, (datetime.datetime(2014, 10, 1), datetime.datetime(2020, 9, 1)), [ "CommonMonthTask(m=2014-10)", "CommonMonthTask(m=2014-11)", "CommonMonthTask(m=2014-12)", "CommonMonthTask(m=2015-01)", ], { "event.tools.range.delay": [ ("CommonMonthTask", (2025 - (2017 - 3)) * 12 - 9), ], "event.tools.range.complete.count": [ ("CommonMonthTask", total - 7), ], "event.tools.range.complete.fraction": [ ("CommonMonthTask", (total - 7.0) / total), ], }, ) def test_zero_months_forward(self): total = (2017 - 2011) * 12 self._nonempty_subcase( { "now": datetime_to_epoch(datetime.datetime(2017, 10, 31, 12, 4, 29)), "start": datetime.date(2011, 10, 1), "task_limit": 10, "months_back": 4, }, (datetime.datetime(2017, 6, 1), datetime.datetime(2017, 9, 1)), [ "CommonMonthTask(m=2017-06)", 
"CommonMonthTask(m=2017-07)", "CommonMonthTask(m=2017-08)", "CommonMonthTask(m=2017-09)", ], { "event.tools.range.delay": [ ("CommonMonthTask", 4), ], "event.tools.range.complete.count": [ ("CommonMonthTask", total - 4), ], "event.tools.range.complete.fraction": [ ("CommonMonthTask", (total - 4.0) / total), ], }, ) def test_months_forward_on_first_of_month(self): total = (2017 - 2011) * 12 + 2 self._nonempty_subcase( { "now": datetime_to_epoch(datetime.datetime(2017, 10, 1, 12, 4, 29)), "start": datetime.date(2011, 10, 1), "task_limit": 10, "months_back": 4, "months_forward": 2, }, (datetime.datetime(2017, 6, 1), datetime.datetime(2017, 11, 1)), [ "CommonMonthTask(m=2017-06)", "CommonMonthTask(m=2017-07)", "CommonMonthTask(m=2017-08)", "CommonMonthTask(m=2017-09)", "CommonMonthTask(m=2017-10)", "CommonMonthTask(m=2017-11)", ], { "event.tools.range.delay": [ ("CommonMonthTask", 6), ], "event.tools.range.complete.count": [ ("CommonMonthTask", total - 6), ], "event.tools.range.complete.fraction": [ ("CommonMonthTask", (total - 6.0) / total), ], }, ) def test_consistent_formatting(self): task = RangeMonthly(of=CommonMonthTask, start=datetime.date(2018, 1, 4)) self.assertEqual(task._format_range([datetime.datetime(2018, 2, 3, 14), datetime.datetime(2018, 4, 5, 21)]), "[2018-02, 2018-04]") class MonthInstantiationTest(LuigiTestCase): def test_old_month_instantiation(self): """ Verify that you can still programmatically set of param as string """ class MyTask(luigi.Task): month_param = luigi.MonthParameter() def complete(self): return False range_task = RangeMonthly( now=datetime_to_epoch(datetime.datetime(2016, 1, 1)), of=MyTask, start=datetime.date(2015, 12, 1), stop=datetime.date(2016, 1, 1) ) expected_task = MyTask(month_param=datetime.date(2015, 12, 1)) self.assertEqual(expected_task, list(range_task._requires())[0]) def test_month_cli_instantiation(self): """ Verify that you can still use Range through CLI """ class MyTask(luigi.Task): task_namespace = "wohoo" 
month_param = luigi.MonthParameter() secret = "some-value-to-sooth-python-linters" comp = False def complete(self): return self.comp def run(self): self.comp = True MyTask.secret = "yay" now = str(int(datetime_to_epoch(datetime.datetime(2016, 1, 1)))) self.run_locally_split("RangeMonthly --of wohoo.MyTask --now {now} --start 2015-12 --stop 2016-01".format(now=now)) self.assertEqual(MyTask(month_param=datetime.date(1934, 12, 1)).secret, "yay") def test_param_name(self): class MyTask(luigi.Task): some_non_range_param = luigi.Parameter(default="woo") month_param = luigi.MonthParameter() def complete(self): return False range_task = RangeMonthly( now=datetime_to_epoch(datetime.datetime(2016, 1, 1)), of=MyTask, start=datetime.date(2015, 12, 1), stop=datetime.date(2016, 1, 1), param_name="month_param", ) expected_task = MyTask("woo", datetime.date(2015, 12, 1)) self.assertEqual(expected_task, list(range_task._requires())[0]) def test_param_name_with_inferred_fs(self): class MyTask(luigi.Task): some_non_range_param = luigi.Parameter(default="woo") month_param = luigi.MonthParameter() def output(self): return MockTarget(self.month_param.strftime("/n2000y01a05n/%Y_%m-aww/21mm%Hdara21/ooo")) range_task = RangeMonthly( now=datetime_to_epoch(datetime.datetime(2016, 1, 1)), of=MyTask, start=datetime.date(2015, 12, 1), stop=datetime.date(2016, 1, 1), param_name="month_param", ) expected_task = MyTask("woo", datetime.date(2015, 12, 1)) self.assertEqual(expected_task, list(range_task._requires())[0]) def test_of_param_distinction(self): class MyTask(luigi.Task): arbitrary_param = luigi.Parameter(default="foo") arbitrary_integer_param = luigi.IntParameter(default=10) month_param = luigi.MonthParameter() def complete(self): return False range_task_1 = RangeMonthly( now=datetime_to_epoch(datetime.datetime(2015, 12, 2)), of=MyTask, start=datetime.date(2015, 12, 1), stop=datetime.date(2016, 1, 1) ) range_task_2 = RangeMonthly( now=datetime_to_epoch(datetime.datetime(2015, 12, 2)), 
of=MyTask, of_params=dict(arbitrary_param="bar", abitrary_integer_param=2), start=datetime.date(2015, 12, 1), stop=datetime.date(2016, 1, 1), ) self.assertNotEqual(range_task_1.task_id, range_task_2.task_id) def test_of_param_commandline(self): class MyTask(luigi.Task): task_namespace = "wohoo" month_param = luigi.MonthParameter() arbitrary_param = luigi.Parameter(default="foo") arbitrary_integer_param = luigi.IntParameter(default=10) state = (None, None) comp = False def complete(self): return self.comp def run(self): self.comp = True MyTask.state = (self.arbitrary_param, self.arbitrary_integer_param) now = str(int(datetime_to_epoch(datetime.datetime(2016, 1, 1)))) self.run_locally( [ "RangeMonthly", "--of", "wohoo.MyTask", "--of-params", '{"arbitrary_param":"bar","arbitrary_integer_param":5}', "--now", "{0}".format(now), "--start", "2015-12", "--stop", "2016-01", ] ) self.assertEqual(MyTask.state, ("bar", 5)) class RangeDailyTest(unittest.TestCase): def test_bulk_complete_correctly_interfaced(self): class BulkCompleteDailyTask(luigi.Task): d = luigi.DateParameter() @classmethod def bulk_complete(self, parameter_tuples): return list(parameter_tuples)[:-2] def output(self): raise RuntimeError("Shouldn't get called while resolving deps via bulk_complete") task = RangeDaily( now=datetime_to_epoch(datetime.datetime(2015, 12, 1)), of=BulkCompleteDailyTask, start=datetime.date(2015, 11, 1), stop=datetime.date(2015, 12, 1) ) expected = [ "BulkCompleteDailyTask(d=2015-11-29)", "BulkCompleteDailyTask(d=2015-11-30)", ] actual = [str(t) for t in task.requires()] self.assertEqual(actual, expected) def test_bulk_complete_of_params(self): class BulkCompleteDailyTask(luigi.Task): non_positional_arbitrary_argument = luigi.Parameter(default="whatever", positional=False, significant=False) d = luigi.DateParameter() arbitrary_argument = luigi.BoolParameter() @classmethod def bulk_complete(cls, parameter_tuples): ptuples = list(parameter_tuples) for t in map(cls, ptuples): assert 
t.arbitrary_argument return ptuples[:-2] def output(self): raise RuntimeError("Shouldn't get called while resolving deps via bulk_complete") task = RangeDaily( now=datetime_to_epoch(datetime.datetime(2015, 12, 1)), of=BulkCompleteDailyTask, of_params=dict(arbitrary_argument=True), start=datetime.date(2015, 11, 1), stop=datetime.date(2015, 12, 1), ) expected = [ "BulkCompleteDailyTask(d=2015-11-29, arbitrary_argument=True)", "BulkCompleteDailyTask(d=2015-11-30, arbitrary_argument=True)", ] actual = [str(t) for t in task.requires()] self.assertEqual(actual, expected) @mock.patch( "luigi.mock.MockFileSystem.listdir", new=mock_listdir( [ "/data/2014/p/v/z/2014_/_03-_-21octor/20/ZOOO", "/data/2014/p/v/z/2014_/_03-_-23octor/20/ZOOO", "/data/2014/p/v/z/2014_/_03-_-24octor/20/ZOOO", ] ), ) @mock.patch("luigi.mock.MockFileSystem.exists", new=mock_exists_always_true) def test_missing_tasks_correctly_required(self): class SomeDailyTask(luigi.Task): d = luigi.DateParameter() def output(self): return MockTarget(self.d.strftime("/data/2014/p/v/z/%Y_/_%m-_-%doctor/20/ZOOO")) task = RangeDaily( now=datetime_to_epoch(datetime.datetime(2016, 4, 1)), of=SomeDailyTask, start=datetime.date(2014, 3, 20), task_limit=3, days_back=3 * 365 ) expected = [ "SomeDailyTask(d=2014-03-20)", "SomeDailyTask(d=2014-03-22)", "SomeDailyTask(d=2014-03-25)", ] actual = [str(t) for t in task.requires()] self.assertEqual(actual, expected) class RangeHourlyTest(unittest.TestCase): # fishy to mock the mock, but MockFileSystem doesn't support globs yet @mock.patch("luigi.mock.MockFileSystem.listdir", new=mock_listdir(mock_contents)) @mock.patch("luigi.mock.MockFileSystem.exists", new=mock_exists_always_true) def test_missing_tasks_correctly_required(self): for task_path in task_a_paths: MockTarget(task_path) # this test takes a few seconds. 
Since stop is not defined, # finite_datetimes constitute many years to consider task = RangeHourly( now=datetime_to_epoch(datetime.datetime(2016, 4, 1)), of=TaskA, start=datetime.datetime(2014, 3, 20, 17), task_limit=3, hours_back=3 * 365 * 24 ) actual = [str(t) for t in task.requires()] self.assertEqual(actual, expected_a) @mock.patch("luigi.mock.MockFileSystem.listdir", new=mock_listdir(mock_contents)) @mock.patch("luigi.mock.MockFileSystem.exists", new=mock_exists_always_true) def test_missing_wrapper_tasks_correctly_required(self): task = RangeHourly( now=datetime_to_epoch(datetime.datetime(2040, 4, 1)), of=CommonWrapperTask, start=datetime.datetime(2014, 3, 20, 23), stop=datetime.datetime(2014, 3, 21, 6), hours_back=30 * 365 * 24, ) actual = [str(t) for t in task.requires()] self.assertEqual(actual, expected_wrapper) def test_bulk_complete_correctly_interfaced(self): class BulkCompleteHourlyTask(luigi.Task): dh = luigi.DateHourParameter() @classmethod def bulk_complete(cls, parameter_tuples): return parameter_tuples[:-2] def output(self): raise RuntimeError("Shouldn't get called while resolving deps via bulk_complete") task = RangeHourly( now=datetime_to_epoch(datetime.datetime(2015, 12, 1)), of=BulkCompleteHourlyTask, start=datetime.datetime(2015, 11, 1), stop=datetime.datetime(2015, 12, 1), ) expected = [ "BulkCompleteHourlyTask(dh=2015-11-30T22)", "BulkCompleteHourlyTask(dh=2015-11-30T23)", ] actual = [str(t) for t in task.requires()] self.assertEqual(actual, expected) def test_bulk_complete_of_params(self): class BulkCompleteHourlyTask(luigi.Task): non_positional_arbitrary_argument = luigi.Parameter(default="whatever", positional=False, significant=False) dh = luigi.DateHourParameter() arbitrary_argument = luigi.BoolParameter() @classmethod def bulk_complete(cls, parameter_tuples): for t in map(cls, parameter_tuples): assert t.arbitrary_argument return parameter_tuples[:-2] def output(self): raise RuntimeError("Shouldn't get called while resolving deps via 
bulk_complete") task = RangeHourly( now=datetime_to_epoch(datetime.datetime(2015, 12, 1)), of=BulkCompleteHourlyTask, of_params=dict(arbitrary_argument=True), start=datetime.datetime(2015, 11, 1), stop=datetime.datetime(2015, 12, 1), ) expected = [ "BulkCompleteHourlyTask(dh=2015-11-30T22, arbitrary_argument=True)", "BulkCompleteHourlyTask(dh=2015-11-30T23, arbitrary_argument=True)", ] actual = [str(t) for t in task.requires()] self.assertEqual(actual, expected) @mock.patch("luigi.mock.MockFileSystem.exists", new=mock_exists_always_false) def test_missing_directory(self): task = RangeHourly( now=datetime_to_epoch(datetime.datetime(2014, 4, 1)), of=TaskC, start=datetime.datetime(2014, 3, 20, 23), stop=datetime.datetime(2014, 3, 21, 1) ) self.assertFalse(task.complete()) expected = ["TaskC(dh=2014-03-20T23)", "TaskC(dh=2014-03-21T00)"] self.assertEqual([str(t) for t in task.requires()], expected) class RangeByMinutesTest(unittest.TestCase): # fishy to mock the mock, but MockFileSystem doesn't support globs yet @mock.patch("luigi.mock.MockFileSystem.listdir", new=mock_listdir(mock_contents)) @mock.patch("luigi.mock.MockFileSystem.exists", new=mock_exists_always_true) def test_missing_tasks_correctly_required(self): expected_tasks = ["SomeByMinutesTask(d=2016-03-31T0000)", "SomeByMinutesTask(d=2016-03-31T0005)", "SomeByMinutesTask(d=2016-03-31T0010)"] class SomeByMinutesTask(luigi.Task): d = luigi.DateMinuteParameter() def output(self): return MockTarget(self.d.strftime("/data/2014/p/v/z/%Y_/_%m-_-%doctor/20/%HZ%MOOO")) for task_path in task_a_paths: MockTarget(task_path) # this test takes a few seconds. 
Since stop is not defined, # finite_datetimes constitute many years to consider task = RangeByMinutes( now=datetime_to_epoch(datetime.datetime(2016, 4, 1)), of=SomeByMinutesTask, start=datetime.datetime(2014, 3, 20, 17), task_limit=3, minutes_back=24 * 60, minutes_interval=5, ) actual = [str(t) for t in task.requires()] self.assertEqual(actual, expected_tasks) @mock.patch("luigi.mock.MockFileSystem.listdir", new=mock_listdir(mock_contents)) @mock.patch("luigi.mock.MockFileSystem.exists", new=mock_exists_always_true) def test_missing_wrapper_tasks_correctly_required(self): expected_wrapper = [ "CommonWrapperTaskMinutes(dm=2014-03-20T2300)", "CommonWrapperTaskMinutes(dm=2014-03-20T2305)", "CommonWrapperTaskMinutes(dm=2014-03-20T2310)", "CommonWrapperTaskMinutes(dm=2014-03-20T2315)", ] task = RangeByMinutes( now=datetime_to_epoch(datetime.datetime(2040, 4, 1, 0, 0, 0)), of=CommonWrapperTaskMinutes, start=datetime.datetime(2014, 3, 20, 23, 0, 0), stop=datetime.datetime(2014, 3, 20, 23, 20, 0), minutes_back=30 * 365 * 24 * 60, minutes_interval=5, ) actual = [str(t) for t in task.requires()] self.assertEqual(actual, expected_wrapper) def test_bulk_complete_correctly_interfaced(self): class BulkCompleteByMinutesTask(luigi.Task): dh = luigi.DateMinuteParameter() @classmethod def bulk_complete(cls, parameter_tuples): return list(parameter_tuples)[:-2] def output(self): raise RuntimeError("Shouldn't get called while resolving deps via bulk_complete") task = RangeByMinutes( now=datetime_to_epoch(datetime.datetime(2015, 12, 1)), of=BulkCompleteByMinutesTask, start=datetime.datetime(2015, 11, 1), stop=datetime.datetime(2015, 12, 1), minutes_interval=5, ) expected = [ "BulkCompleteByMinutesTask(dh=2015-11-30T2350)", "BulkCompleteByMinutesTask(dh=2015-11-30T2355)", ] actual = [str(t) for t in task.requires()] self.assertEqual(actual, expected) def test_bulk_complete_of_params(self): class BulkCompleteByMinutesTask(luigi.Task): non_positional_arbitrary_argument = 
luigi.Parameter(default="whatever", positional=False, significant=False) dh = luigi.DateMinuteParameter() arbitrary_argument = luigi.BoolParameter() @classmethod def bulk_complete(cls, parameter_tuples): ptuples = list(parameter_tuples) for t in map(cls, parameter_tuples): assert t.arbitrary_argument return ptuples[:-2] def output(self): raise RuntimeError("Shouldn't get called while resolving deps via bulk_complete") task = RangeByMinutes( now=datetime_to_epoch(datetime.datetime(2015, 12, 1)), of=BulkCompleteByMinutesTask, of_params=dict(arbitrary_argument=True), start=datetime.datetime(2015, 11, 1), stop=datetime.datetime(2015, 12, 1), minutes_interval=5, ) expected = [ "BulkCompleteByMinutesTask(dh=2015-11-30T2350, arbitrary_argument=True)", "BulkCompleteByMinutesTask(dh=2015-11-30T2355, arbitrary_argument=True)", ] actual = [str(t) for t in task.requires()] self.assertEqual(actual, expected) @mock.patch("luigi.mock.MockFileSystem.exists", new=mock_exists_always_false) def test_missing_directory(self): task = RangeByMinutes( now=datetime_to_epoch(datetime.datetime(2014, 3, 21, 0, 0)), of=TaskMinutesC, start=datetime.datetime(2014, 3, 20, 23, 11), stop=datetime.datetime(2014, 3, 20, 23, 21), minutes_interval=5, ) self.assertFalse(task.complete()) expected = ["TaskMinutesC(dm=2014-03-20T2315)", "TaskMinutesC(dm=2014-03-20T2320)"] self.assertEqual([str(t) for t in task.requires()], expected) class RangeInstantiationTest(LuigiTestCase): def test_old_instantiation(self): """ Verify that you can still programmatically set of param as string """ class MyTask(luigi.Task): date_param = luigi.DateParameter() def complete(self): return False range_task = RangeDailyBase( now=datetime_to_epoch(datetime.datetime(2015, 12, 2)), of=MyTask, start=datetime.date(2015, 12, 1), stop=datetime.date(2015, 12, 2) ) expected_task = MyTask(date_param=datetime.date(2015, 12, 1)) self.assertEqual(expected_task, list(range_task._requires())[0]) def test_cli_instantiation(self): """ Verify 
that you can still use Range through CLI """ class MyTask(luigi.Task): task_namespace = "wohoo" date_param = luigi.DateParameter() secret = "some-value-to-sooth-python-linters" comp = False def complete(self): return self.comp def run(self): self.comp = True MyTask.secret = "yay" now = str(int(datetime_to_epoch(datetime.datetime(2015, 12, 2)))) self.run_locally_split("RangeDailyBase --of wohoo.MyTask --now {now} --start 2015-12-01 --stop 2015-12-02".format(now=now)) self.assertEqual(MyTask(date_param=datetime.date(1934, 12, 1)).secret, "yay") def test_param_name(self): class MyTask(luigi.Task): some_non_range_param = luigi.Parameter(default="woo") date_param = luigi.DateParameter() def complete(self): return False range_task = RangeDailyBase( now=datetime_to_epoch(datetime.datetime(2015, 12, 2)), of=MyTask, start=datetime.date(2015, 12, 1), stop=datetime.date(2015, 12, 2), param_name="date_param", ) expected_task = MyTask("woo", datetime.date(2015, 12, 1)) self.assertEqual(expected_task, list(range_task._requires())[0]) def test_param_name_with_inferred_fs(self): class MyTask(luigi.Task): some_non_range_param = luigi.Parameter(default="woo") date_param = luigi.DateParameter() def output(self): return MockTarget(self.date_param.strftime("/n2000y01a05n/%Y_%m-_-%daww/21mm%Hdara21/ooo")) range_task = RangeDaily( now=datetime_to_epoch(datetime.datetime(2015, 12, 2)), of=MyTask, start=datetime.date(2015, 12, 1), stop=datetime.date(2015, 12, 2), param_name="date_param", ) expected_task = MyTask("woo", datetime.date(2015, 12, 1)) self.assertEqual(expected_task, list(range_task._requires())[0]) def test_of_param_distinction(self): class MyTask(luigi.Task): arbitrary_param = luigi.Parameter(default="foo") arbitrary_integer_param = luigi.IntParameter(default=10) date_param = luigi.DateParameter() def complete(self): return False range_task_1 = RangeDaily( now=datetime_to_epoch(datetime.datetime(2015, 12, 2)), of=MyTask, start=datetime.date(2015, 12, 1), 
stop=datetime.date(2015, 12, 2) ) range_task_2 = RangeDaily( now=datetime_to_epoch(datetime.datetime(2015, 12, 2)), of=MyTask, of_params=dict(arbitrary_param="bar", abitrary_integer_param=2), start=datetime.date(2015, 12, 1), stop=datetime.date(2015, 12, 2), ) self.assertNotEqual(range_task_1.task_id, range_task_2.task_id) def test_of_param_commandline(self): class MyTask(luigi.Task): task_namespace = "wohoo" date_param = luigi.DateParameter() arbitrary_param = luigi.Parameter(default="foo") arbitrary_integer_param = luigi.IntParameter(default=10) state = (None, None) comp = False def complete(self): return self.comp def run(self): self.comp = True MyTask.state = (self.arbitrary_param, self.arbitrary_integer_param) now = str(int(datetime_to_epoch(datetime.datetime(2015, 12, 2)))) self.run_locally( [ "RangeDailyBase", "--of", "wohoo.MyTask", "--of-params", '{"arbitrary_param":"bar","arbitrary_integer_param":5}', "--now", "{0}".format(now), "--start", "2015-12-01", "--stop", "2015-12-02", ] ) self.assertEqual(MyTask.state, ("bar", 5)) ================================================ FILE: test/recursion_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import datetime from helpers import unittest import luigi import luigi.interface from luigi.mock import MockTarget class Popularity(luigi.Task): date = luigi.DateParameter(default=datetime.date.today() - datetime.timedelta(1)) def output(self): return MockTarget("/tmp/popularity/%s.txt" % self.date.strftime("%Y-%m-%d")) def requires(self): return Popularity(self.date - datetime.timedelta(1)) def run(self): f = self.output().open("w") for line in self.input().open("r"): print(int(line.strip()) + 1, file=f) f.close() class RecursionTest(unittest.TestCase): def setUp(self): MockTarget.fs.get_all_data()["/tmp/popularity/2009-01-01.txt"] = b"0\n" def test_invoke(self): luigi.build([Popularity(datetime.date(2009, 1, 5))], local_scheduler=True) self.assertEqual(MockTarget.fs.get_data("/tmp/popularity/2009-01-05.txt"), b"4\n") ================================================ FILE: test/remote_scheduler_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import os import tempfile import server_test import luigi.server tempdir = tempfile.mkdtemp() class DummyTask(luigi.Task): id = luigi.IntParameter() def run(self): f = self.output().open("w") f.close() def output(self): return luigi.LocalTarget(os.path.join(tempdir, str(self.id))) class RemoteSchedulerTest(server_test.ServerTestBase): def _test_run(self, workers): tasks = [DummyTask(id) for id in range(20)] luigi.build(tasks, workers=workers, scheduler_port=self.get_http_port()) for t in tasks: self.assertEqual(t.complete(), True) self.assertTrue(os.path.exists(t.output().path)) def test_single_worker(self): self._test_run(workers=1) def test_multiple_workers(self): self._test_run(workers=10) ================================================ FILE: test/retcodes_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2015-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import mock from helpers import LuigiTestCase, with_config import luigi import luigi.scheduler from luigi.cmdline import luigi_run class RetcodesTest(LuigiTestCase): def run_and_expect(self, joined_params, retcode, extra_args=["--local-scheduler", "--no-lock"]): with self.assertRaises(SystemExit) as cm: luigi_run((joined_params.split(" ") + extra_args)) self.assertEqual(cm.exception.code, retcode) def run_with_config(self, retcode_config, *args, **kwargs): with_config(dict(retcode=retcode_config))(self.run_and_expect)(*args, **kwargs) def test_task_failed(self): class FailingTask(luigi.Task): def run(self): raise ValueError() self.run_and_expect("FailingTask", 0) # Test default value to be 0 self.run_and_expect("FailingTask --retcode-task-failed 5", 5) self.run_with_config(dict(task_failed="3"), "FailingTask", 3) def test_missing_data(self): class MissingDataTask(luigi.ExternalTask): def complete(self): return False self.run_and_expect("MissingDataTask", 0) # Test default value to be 0 self.run_and_expect("MissingDataTask --retcode-missing-data 5", 5) self.run_with_config(dict(missing_data="3"), "MissingDataTask", 3) def test_already_running(self): class AlreadyRunningTask(luigi.Task): def run(self): pass old_func = luigi.scheduler.Scheduler.get_work def new_func(*args, **kwargs): kwargs["current_tasks"] = None old_func(*args, **kwargs) res = old_func(*args, **kwargs) res["running_tasks"][0]["worker"] = "not me :)" # Otherwise it will be filtered return res with mock.patch("luigi.scheduler.Scheduler.get_work", new_func): self.run_and_expect("AlreadyRunningTask", 0) # Test default value to be 0 self.run_and_expect("AlreadyRunningTask --retcode-already-running 5", 5) self.run_with_config(dict(already_running="3"), "AlreadyRunningTask", 3) def test_when_locked(self): def new_func(*args, **kwargs): return False with mock.patch("luigi.lock.acquire_for", new_func): self.run_and_expect("Task", 0, extra_args=["--local-scheduler"]) self.run_and_expect("Task 
--retcode-already-running 5", 5, extra_args=["--local-scheduler"]) self.run_with_config(dict(already_running="3"), "Task", 3, extra_args=["--local-scheduler"]) def test_failure_in_complete(self): class FailingComplete(luigi.Task): def complete(self): raise Exception class RequiringTask(luigi.Task): def requires(self): yield FailingComplete() self.run_and_expect("RequiringTask", 0) def test_failure_in_requires(self): class FailingRequires(luigi.Task): def requires(self): raise Exception self.run_and_expect("FailingRequires", 0) def test_validate_dependency_error(self): # requires() from RequiringTask expects a Task object class DependencyTask: pass class RequiringTask(luigi.Task): def requires(self): yield DependencyTask() self.run_and_expect("RequiringTask", 4) def test_task_limit(self): class TaskB(luigi.Task): def complete(self): return False class TaskA(luigi.Task): def requires(sefl): yield TaskB() class TaskLimitTest(luigi.Task): def requires(self): yield TaskA() self.run_and_expect("TaskLimitTest --worker-task-limit 2", 0) self.run_and_expect("TaskLimitTest --worker-task-limit 2 --retcode-scheduling-error 3", 3) def test_unhandled_exception(self): def new_func(*args, **kwargs): raise Exception() with mock.patch("luigi.worker.Worker.add", new_func): self.run_and_expect("Task", 4) self.run_and_expect("Task --retcode-unhandled-exception 2", 2) class TaskWithRequiredParam(luigi.Task): param = luigi.Parameter() self.run_and_expect("TaskWithRequiredParam --param hello", 0) self.run_and_expect("TaskWithRequiredParam", 4) def test_when_mixed_errors(self): class FailingTask(luigi.Task): def run(self): raise ValueError() class MissingDataTask(luigi.ExternalTask): def complete(self): return False class RequiringTask(luigi.Task): def requires(self): yield FailingTask() yield MissingDataTask() self.run_and_expect("RequiringTask --retcode-task-failed 4 --retcode-missing-data 5", 5) self.run_and_expect("RequiringTask --retcode-task-failed 7 --retcode-missing-data 6", 7) def 
test_unknown_reason(self): class TaskA(luigi.Task): def complete(self): return True class RequiringTask(luigi.Task): def requires(self): yield TaskA() def new_func(*args, **kwargs): return None with mock.patch("luigi.scheduler.Scheduler.add_task", new_func): self.run_and_expect("RequiringTask", 0) self.run_and_expect("RequiringTask --retcode-not-run 5", 5) """ Test that a task once crashing and then succeeding should be counted as no failure. """ def test_retry_sucess_task(self): class Foo(luigi.Task): run_count = 0 def run(self): self.run_count += 1 if self.run_count == 1: raise ValueError() def complete(self): return self.run_count > 0 self.run_and_expect("Foo --scheduler-retry-delay=0", 0) self.run_and_expect("Foo --scheduler-retry-delay=0 --retcode-task-failed=5", 0) self.run_with_config(dict(task_failed="3"), "Foo", 0) ================================================ FILE: test/rpc_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#
from helpers import unittest, with_config

try:
    from unittest import mock
except ImportError:
    import mock
import socket
from multiprocessing import Process, Queue

import scheduler_api_test
from server_test import ServerTestBase

import luigi.rpc
import luigi.server
from luigi.scheduler import Scheduler


class RemoteSchedulerTest(unittest.TestCase):
    """Unit tests for RemoteScheduler URL handling and RPC retry behaviour (fetcher mocked)."""

    def _assert_url_variations(self, urls, expected_url):
        """Assert every base-url/suffix slash combination is joined into ``expected_url``."""
        for url in urls:
            for suffix in ["api/123", "/api/123"]:
                with self.subTest(url=url, suffix=suffix):
                    s = luigi.rpc.RemoteScheduler(url, 42)
                    with mock.patch.object(s, "_fetcher") as fetcher:
                        s._fetch(suffix, "{}")
                        fetcher.fetch.assert_called_once_with(expected_url, "{}", 42)

    def testUrlArgumentVariations(self):
        self._assert_url_variations(["http://zorg.com", "http://zorg.com/"], "http://zorg.com/api/123")

    def testUrlArgumentVariationsNotRoot(self):
        self._assert_url_variations(["http://zorg.com/subpath", "http://zorg.com/subpath/"], "http://zorg.com/subpath/api/123")

    def get_work(self, fetcher_side_effect):
        """Call ``get_work`` on a RemoteScheduler whose fetcher yields ``fetcher_side_effect``."""
        scheduler = luigi.rpc.RemoteScheduler("http://zorg.com", 42)
        scheduler._rpc_retry_wait = 1  # shorten wait time to speed up tests

        with mock.patch.object(scheduler, "_fetcher") as fetcher:
            # Exception types the scheduler should treat as retryable.
            fetcher.raises = socket.timeout, socket.gaierror
            fetcher.fetch.side_effect = fetcher_side_effect
            return scheduler.get_work("fake_worker")

    def test_retry_rpc_method(self):
        """
        Tests that a call to an RPC method is re-tried 3 times.
        """

        fetch_results = [socket.timeout, socket.timeout, '{"response":{}}']
        self.assertEqual({}, self.get_work(fetch_results))

    def test_retry_rpc_limited(self):
        """
        Tests that a call to an RPC method fails after the third attempt
        """

        fetch_results = [socket.timeout, socket.timeout, socket.timeout]
        self.assertRaises(luigi.rpc.RPCError, self.get_work, fetch_results)

    @mock.patch("luigi.rpc.logger")
    def test_log_rpc_retries_enabled(self, mock_logger):
        """
        Tests that each retry of an RPC method is logged
        """

        fetch_results = [socket.timeout, socket.timeout, '{"response":{}}']
        self.get_work(fetch_results)
        self.assertEqual(
            [
                mock.call.warning("Failed connecting to remote scheduler %r", "http://zorg.com", exc_info=True),
                mock.call.info("Retrying attempt 2 of 3 (max)"),
                mock.call.info("Wait for 1 seconds"),
                mock.call.warning("Failed connecting to remote scheduler %r", "http://zorg.com", exc_info=True),
                mock.call.info("Retrying attempt 3 of 3 (max)"),
                mock.call.info("Wait for 1 seconds"),
            ],
            mock_logger.mock_calls,
        )

    @with_config({"core": {"rpc-log-retries": "false"}})
    @mock.patch("luigi.rpc.logger")
    def test_log_rpc_retries_disabled(self, mock_logger):
        """
        Tests that retries of an RPC method are not logged
        """

        fetch_results = [socket.timeout, socket.timeout, socket.gaierror]
        with self.assertRaises(luigi.rpc.RPCError) as cm:
            self.get_work(fetch_results)
        # The final (third) failure is the one wrapped in the RPCError.
        self.assertIsInstance(cm.exception.sub_exception, socket.gaierror)
        self.assertEqual([], mock_logger.mock_calls)

    def test_get_work_retries_on_null(self):
        """
        Tests that get_work will retry if the response is null
        """

        fetch_results = ['{"response": null}', '{"response": {"pass": true}}']
        self.assertEqual({"pass": True}, self.get_work(fetch_results))

    def test_get_work_retries_on_null_limited(self):
        """
        Tests that get_work will give up after the third null response
        """

        fetch_results = ['{"response": null}'] * 3 + ['{"response": {}}']
        self.assertRaises(luigi.rpc.RPCError, self.get_work, fetch_results)
class RPCTest(scheduler_api_test.SchedulerApiTest, ServerTestBase): def get_app(self): conf = self.get_scheduler_config() sch = Scheduler(**conf) return luigi.server.app(sch) def setUp(self): super(RPCTest, self).setUp() self.sch = luigi.rpc.RemoteScheduler(self.get_url("")) self.sch._wait = lambda: None # disable test that doesn't work with remote scheduler def test_task_first_failure_time(self): pass def test_task_first_failure_time_remains_constant(self): pass def test_task_has_excessive_failures(self): pass def test_quadratic_behavior(self): """This would be too slow to run through network""" pass def test_get_work_speed(self): """This would be too slow to run through network""" pass class RequestsFetcherTest(ServerTestBase): def test_fork_changes_session(self): fetcher = luigi.rpc.RequestsFetcher() session = fetcher.session q = Queue() def check_session(q): fetcher.check_pid() # make sure that check_pid has changed out the session q.put(fetcher.session != session) p = Process(target=check_session, args=(q,)) p.start() p.join() self.assertTrue(q.get(), "the requests.Session should have changed in the new process") class URLLibFetcherTest(ServerTestBase): def test_url_with_basic_auth(self): fetcher = luigi.rpc.URLLibFetcher() # without password req = fetcher._create_request("http://user@localhost") self.assertTrue(req.has_header("Authorization")) self.assertEqual(req.get_header("Authorization"), "Basic dXNlcjo=") self.assertEqual(req.get_full_url(), "http://localhost") # empty password (same as above) req = fetcher._create_request("http://user:@localhost") self.assertTrue(req.has_header("Authorization")) self.assertEqual(req.get_header("Authorization"), "Basic dXNlcjo=") self.assertEqual(req.get_full_url(), "http://localhost") # with password req = fetcher._create_request("http://user:pass@localhost") self.assertTrue(req.has_header("Authorization")) self.assertEqual(req.get_header("Authorization"), "Basic dXNlcjpwYXNz") self.assertEqual(req.get_full_url(), 
"http://localhost") def test_url_without_basic_auth(self): fetcher = luigi.rpc.URLLibFetcher() req = fetcher._create_request("http://localhost") self.assertFalse(req.has_header("Authorization")) self.assertEqual(req.get_full_url(), "http://localhost") def test_body_encoding(self): fetcher = luigi.rpc.URLLibFetcher() # with body req = fetcher._create_request("http://localhost", body={"foo": "bar baz/test"}) self.assertEqual(req.data, b"foo=bar+baz%2Ftest") # without body req = fetcher._create_request("http://localhost") self.assertIsNone(req.data) ================================================ FILE: test/runtests.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import sys import warnings import pytest if __name__ == "__main__": with warnings.catch_warnings(): warnings.simplefilter("default") warnings.filterwarnings("ignore", message="(.*)outputs has no custom(.*)", category=UserWarning) sys.exit(pytest.main(sys.argv[1:])) ================================================ FILE: test/safe_extractor_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Safe Extractor Test ============= Tests for the Safe Extractor class in luigi.safe_extractor module. """ import os import shutil import tarfile import tempfile import unittest from luigi.safe_extractor import SafeExtractor class TestSafeExtract(unittest.TestCase): """ Unit test class for testing the SafeExtractor module. """ def setUp(self): """Set up a temporary directory for test files.""" self.temp_dir = tempfile.mkdtemp() self.test_file_template = "test_file_{}.txt" self.tar_file_name = "test.tar" self.tar_file_name_with_traversal = f"traversal_{self.tar_file_name}" def tearDown(self): """Clean up the temporary directory after each test.""" shutil.rmtree(self.temp_dir) def create_test_tar(self, tar_path, file_count=1, with_traversal=False): """ Create a tar file containing test files. Args: tar_path (str): Path where the tar file will be created. file_count (int): Number of test files to include. with_traversal (bool): If True, creates a tar file with path traversal vulnerability. 
""" # Default content for the test files file_contents = [f"This is {self.test_file_template.format(i)}" for i in range(file_count)] with tarfile.open(tar_path, "w") as tar: for i in range(file_count): file_name = self.test_file_template.format(i) file_path = os.path.join(self.temp_dir, file_name) # Write content to each test file with open(file_path, "w") as f: f.write(file_contents[i]) # If path traversal is enabled, create malicious paths archive_name = f"../../{file_name}" if with_traversal else file_name # Add the file to the tar archive tar.add(file_path, arcname=archive_name) def verify_extracted_files(self, file_count): """ Verify that the correct files were extracted and their contents match expectations. Args: file_count (int): Number of files to verify. """ for i in range(file_count): file_name = self.test_file_template.format(i) file_path = os.path.join(self.temp_dir, file_name) # Check if the file exists self.assertTrue(os.path.exists(file_path), f"File {file_name} does not exist.") # Check if the file content is correct with open(file_path, "r") as f: content = f.read() expected_content = f"This is {file_name}" self.assertEqual(content, expected_content, f"Content mismatch in {file_name}.") def test_safe_extract(self): """Test normal safe extraction of tar files.""" tar_path = os.path.join(self.temp_dir, self.tar_file_name) # Create a tar file with 3 files self.create_test_tar(tar_path, file_count=3) # Initialize SafeExtractor and perform extraction extractor = SafeExtractor(self.temp_dir) extractor.safe_extract(tar_path) # Verify that all 3 files were extracted correctly self.verify_extracted_files(3) def test_safe_extract_with_traversal(self): """Test safe extraction for tar files with path traversal (should raise an error).""" tar_path = os.path.join(self.temp_dir, self.tar_file_name_with_traversal) # Create a tar file with a path traversal file self.create_test_tar(tar_path, file_count=1, with_traversal=True) # Initialize SafeExtractor and expect 
RuntimeError due to path traversal extractor = SafeExtractor(self.temp_dir) with self.assertRaises(RuntimeError): extractor.safe_extract(tar_path) if __name__ == "__main__": unittest.main() ================================================ FILE: test/scheduler_api_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import itertools import time import mock import pytest from helpers import unittest import luigi.notifications from luigi.scheduler import BATCH_RUNNING, DISABLED, DONE, FAILED, PENDING, RUNNING, UNKNOWN, UPSTREAM_RUNNING, Scheduler luigi.notifications.DEBUG = True WORKER = "myworker" @pytest.mark.scheduler class SchedulerApiTest(unittest.TestCase): def setUp(self): super(SchedulerApiTest, self).setUp() conf = self.get_scheduler_config() self.sch = Scheduler(**conf) self.time = time.time def get_scheduler_config(self): return { "retry_delay": 100, "remove_delay": 1000, "worker_disconnect_delay": 10, "disable_persist": 10, "disable_window": 10, "retry_count": 3, "disable_hard_timeout": 60 * 60, "stable_done_cooldown_secs": 0, } def tearDown(self): super(SchedulerApiTest, self).tearDown() if time.time != self.time: time.time = self.time def setTime(self, t): time.time = lambda: t def test_dep(self): self.sch.add_task(worker=WORKER, task_id="B", deps=("A",)) self.sch.add_task(worker=WORKER, task_id="A") 
self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A") self.sch.add_task(worker=WORKER, task_id="A", status=DONE) self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "B") self.sch.add_task(worker=WORKER, task_id="B", status=DONE) self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], None) def test_failed_dep(self): self.sch.add_task(worker=WORKER, task_id="B", deps=("A",)) self.sch.add_task(worker=WORKER, task_id="A") self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A") self.sch.add_task(worker=WORKER, task_id="A", status=FAILED) self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], None) # can still wait and retry: TODO: do we want this? self.sch.add_task(worker=WORKER, task_id="A", status=DONE) self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "B") self.sch.add_task(worker=WORKER, task_id="B", status=DONE) self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], None) def test_broken_dep(self): self.sch.add_task(worker=WORKER, task_id="B", deps=("A",)) self.sch.add_task(worker=WORKER, task_id="A", runnable=False) self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], None) # can still wait and retry: TODO: do we want this? 
self.sch.add_task(worker=WORKER, task_id="A", status=DONE) self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "B") self.sch.add_task(worker=WORKER, task_id="B", status=DONE) self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], None) def test_two_workers(self): # Worker X wants to build A -> B # Worker Y wants to build A -> C self.sch.add_task(worker="X", task_id="A") self.sch.add_task(worker="Y", task_id="A") self.sch.add_task(task_id="B", deps=("A",), worker="X") self.sch.add_task(task_id="C", deps=("A",), worker="Y") self.assertEqual(self.sch.get_work(worker="X")["task_id"], "A") self.assertEqual(self.sch.get_work(worker="Y")["task_id"], None) # Worker Y is pending on A to be done self.sch.add_task(worker="X", task_id="A", status=DONE) self.assertEqual(self.sch.get_work(worker="Y")["task_id"], "C") self.assertEqual(self.sch.get_work(worker="X")["task_id"], "B") def test_status_wont_override(self): # Worker X is running A # Worker Y wants to override the status to UNKNOWN (e.g. 
complete is throwing an exception) self.sch.add_task(worker="X", task_id="A") self.assertEqual(self.sch.get_work(worker="X")["task_id"], "A") self.sch.add_task(worker="Y", task_id="A", status=UNKNOWN) self.assertEqual({"A"}, set(self.sch.task_list(RUNNING, "").keys())) def test_retry(self): # Try to build A but fails, will retry after 100s self.setTime(0) self.sch.add_task(worker=WORKER, task_id="A") self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A") self.sch.add_task(worker=WORKER, task_id="A", status=FAILED) for t in range(100): self.setTime(t) self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], None) self.sch.ping(worker=WORKER) if t % 10 == 0: self.sch.prune() self.setTime(101) self.sch.prune() self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A") def test_resend_task(self): self.sch.add_task(worker=WORKER, task_id="A") self.sch.add_task(worker=WORKER, task_id="B") for _ in range(10): self.assertEqual("A", self.sch.get_work(worker=WORKER, current_tasks=[])["task_id"]) self.assertEqual("B", self.sch.get_work(worker=WORKER, current_tasks=["A"])["task_id"]) def test_resend_multiple_tasks(self): self.sch.add_task(worker=WORKER, task_id="A") self.sch.add_task(worker=WORKER, task_id="B") self.sch.add_task(worker=WORKER, task_id="C") # get A and B running self.assertEqual("A", self.sch.get_work(worker=WORKER)["task_id"]) self.assertEqual("B", self.sch.get_work(worker=WORKER)["task_id"]) for _ in range(10): self.assertEqual("A", self.sch.get_work(worker=WORKER, current_tasks=[])["task_id"]) self.assertEqual("A", self.sch.get_work(worker=WORKER, current_tasks=["B"])["task_id"]) self.assertEqual("B", self.sch.get_work(worker=WORKER, current_tasks=["A"])["task_id"]) self.assertEqual("C", self.sch.get_work(worker=WORKER, current_tasks=["A", "B"])["task_id"]) def test_disconnect_running(self): # X and Y wants to run A. # X starts but does not report back. Y does. 
# After some timeout, Y will build it instead self.setTime(0) self.sch.add_task(task_id="A", worker="X") self.sch.add_task(task_id="A", worker="Y") self.assertEqual(self.sch.get_work(worker="X")["task_id"], "A") for t in range(200): self.setTime(t) self.sch.ping(worker="Y") if t % 10 == 0: self.sch.prune() self.assertEqual(self.sch.get_work(worker="Y")["task_id"], "A") def test_get_work_single_batch_item(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_a_1", family="A", params={"a": "1"}, batchable=True) response = self.sch.get_work(worker=WORKER) self.assertEqual("A_a_1", response["task_id"]) param_values = response["task_params"].values() self.assertTrue(not any(isinstance(param, list)) for param in param_values) def test_get_work_multiple_batch_items(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_a_1", family="A", params={"a": "1"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_a_2", family="A", params={"a": "2"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_a_3", family="A", params={"a": "3"}, batchable=True) response = self.sch.get_work(worker=WORKER) self.assertIsNone(response["task_id"]) self.assertEqual({"a": ["1", "2", "3"]}, response["task_params"]) self.assertEqual("A", response["task_family"]) def test_batch_time_running(self): self.setTime(1234) self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_a_1", family="A", params={"a": "1"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_a_2", family="A", params={"a": "2"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_a_3", family="A", params={"a": "3"}, batchable=True) self.sch.get_work(worker=WORKER) for task in self.sch.task_list().values(): self.assertEqual(1234, task["time_running"]) def 
test_batch_ignore_items_not_ready(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_a_1", family="A", params={"a": "1"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_a_2", family="A", params={"a": "2"}, deps=["NOT_DONE"], batchable=True) self.sch.add_task(worker=WORKER, task_id="A_a_3", family="A", params={"a": "3"}, deps=["DONE"], batchable=True) self.sch.add_task(worker=WORKER, task_id="A_a_4", family="A", params={"a": "4"}, deps=["DONE"], batchable=True) self.sch.add_task(worker=WORKER, task_id="A_a_5", family="A", params={"a": "5"}, deps=["NOT_DONE"], batchable=True) self.sch.add_task(worker=WORKER, task_id="NOT_DONE", runnable=False) self.sch.add_task(worker=WORKER, task_id="DONE", status=DONE) response = self.sch.get_work(worker=WORKER) self.assertIsNone(response["task_id"]) self.assertEqual({"a": ["1", "3", "4"]}, response["task_params"]) self.assertEqual("A", response["task_family"]) def test_batch_ignore_first_item_not_ready(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_a_1", family="A", params={"a": "1"}, deps=["NOT_DONE"], batchable=True) self.sch.add_task(worker=WORKER, task_id="A_a_2", family="A", params={"a": "2"}, deps=["DONE"], batchable=True) self.sch.add_task(worker=WORKER, task_id="A_a_3", family="A", params={"a": "3"}, deps=["DONE"], batchable=True) self.sch.add_task(worker=WORKER, task_id="NOT_DONE", runnable=False) self.sch.add_task(worker=WORKER, task_id="DONE", status=DONE) response = self.sch.get_work(worker=WORKER) self.assertIsNone(response["task_id"]) self.assertEqual({"a": ["2", "3"]}, response["task_params"]) self.assertEqual("A", response["task_family"]) def test_get_work_with_batch_items_with_resources(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_a_1", family="A", params={"a": 
"1"}, batchable=True, resources={"r1": 1}) self.sch.add_task(worker=WORKER, task_id="A_a_2", family="A", params={"a": "2"}, batchable=True, resources={"r1": 1}) self.sch.add_task(worker=WORKER, task_id="A_a_3", family="A", params={"a": "3"}, batchable=True, resources={"r1": 1}) response = self.sch.get_work(worker=WORKER) self.assertIsNone(response["task_id"]) self.assertEqual({"a": ["1", "2", "3"]}, response["task_params"]) self.assertEqual("A", response["task_family"]) def test_get_work_limited_batch_size(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"], max_batch_size=2) self.sch.add_task(worker=WORKER, task_id="A_a_1", family="A", params={"a": "1"}, batchable=True, priority=1) self.sch.add_task(worker=WORKER, task_id="A_a_2", family="A", params={"a": "2"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_a_3", family="A", params={"a": "3"}, batchable=True, priority=2) response = self.sch.get_work(worker=WORKER) self.assertIsNone(response["task_id"]) self.assertEqual({"a": ["3", "1"]}, response["task_params"]) self.assertEqual("A", response["task_family"]) response2 = self.sch.get_work(worker=WORKER) self.assertEqual("A_a_2", response2["task_id"]) def test_get_work_do_not_batch_non_batchable_item(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_a_1", family="A", params={"a": "1"}, batchable=True, priority=1) self.sch.add_task(worker=WORKER, task_id="A_a_2", family="A", params={"a": "2"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_a_3", family="A", params={"a": "3"}, batchable=False, priority=2) response = self.sch.get_work(worker=WORKER) self.assertEqual("A_a_3", response["task_id"]) response2 = self.sch.get_work(worker=WORKER) self.assertIsNone(response2["task_id"]) self.assertEqual({"a": ["1", "2"]}, response2["task_params"]) self.assertEqual("A", response2["task_family"]) def 
test_get_work_group_on_non_batch_params(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["b"]) for a, b, c in itertools.product((1, 2), repeat=3): self.sch.add_task( worker=WORKER, task_id="A_%i_%i_%i" % (a, b, c), family="A", params={"a": str(a), "b": str(b), "c": str(c)}, batchable=True, priority=9 * a + 3 * c + b, ) for a, c in [("2", "2"), ("2", "1"), ("1", "2"), ("1", "1")]: response = self.sch.get_work(worker=WORKER) self.assertIsNone(response["task_id"]) self.assertEqual({"a": a, "b": ["2", "1"], "c": c}, response["task_params"]) self.assertEqual("A", response["task_family"]) def test_get_work_multiple_batched_params(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a", "b"]) self.sch.add_task(worker=WORKER, task_id="A_1_1", family="A", params={"a": "1", "b": "1"}, priority=1, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_1_2", family="A", params={"a": "1", "b": "2"}, priority=2, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_2_1", family="A", params={"a": "2", "b": "1"}, priority=3, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_2_2", family="A", params={"a": "2", "b": "2"}, priority=4, batchable=True) response = self.sch.get_work(worker=WORKER) self.assertIsNone(response["task_id"]) expected_params = { "a": ["2", "2", "1", "1"], "b": ["2", "1", "2", "1"], } self.assertEqual(expected_params, response["task_params"]) def test_get_work_with_unbatched_worker_on_batched_task(self): self.sch.add_task_batcher(worker="batcher", task_family="A", batched_args=["a"]) for i in range(5): self.sch.add_task(worker=WORKER, task_id="A_%i" % i, family="A", params={"a": str(i)}, priority=i, batchable=False) self.sch.add_task(worker="batcher", task_id="A_%i" % i, family="A", params={"a": str(i)}, priority=i, batchable=True) self.assertEqual("A_4", self.sch.get_work(worker=WORKER)["task_id"]) batch_response = self.sch.get_work(worker="batcher") 
self.assertIsNone(batch_response["task_id"]) self.assertEqual({"a": ["3", "2", "1", "0"]}, batch_response["task_params"]) def test_batched_tasks_become_batch_running(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": 1}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": 2}, batchable=True) self.sch.get_work(worker=WORKER) self.assertEqual({"A_1", "A_2"}, set(self.sch.task_list("BATCH_RUNNING", "").keys())) def test_downstream_jobs_from_batch_running_have_upstream_running_status(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": 1}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": 2}, batchable=True) self.sch.get_work(worker=WORKER) self.assertEqual({"A_1", "A_2"}, set(self.sch.task_list("BATCH_RUNNING", "").keys())) self.sch.add_task(worker=WORKER, task_id="B", deps=["A_1"]) self.assertEqual({"B"}, set(self.sch.task_list(PENDING, UPSTREAM_RUNNING).keys())) def test_set_batch_runner_new_task(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": "1"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": "2"}, batchable=True) response = self.sch.get_work(worker=WORKER) batch_id = response["batch_id"] self.sch.add_task(worker=WORKER, task_id="A_1_2", task_family="A", params={"a": "1,2"}, batch_id=batch_id, status="RUNNING") self.assertEqual({"A_1", "A_2"}, set(self.sch.task_list("BATCH_RUNNING", "").keys())) self.assertEqual({"A_1_2"}, set(self.sch.task_list("RUNNING", "").keys())) self.sch.add_task(worker=WORKER, task_id="A_1_2", status=DONE) self.assertEqual({"A_1", "A_2", "A_1_2"}, set(self.sch.task_list(DONE, "").keys())) def 
test_set_batch_runner_max(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": "1"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": "2"}, batchable=True) response = self.sch.get_work(worker=WORKER) batch_id = response["batch_id"] self.sch.add_task(worker=WORKER, task_id="A_2", task_family="A", params={"a": "2"}, batch_id=batch_id, status="RUNNING") self.assertEqual({"A_1"}, set(self.sch.task_list("BATCH_RUNNING", "").keys())) self.assertEqual({"A_2"}, set(self.sch.task_list("RUNNING", "").keys())) self.sch.add_task(worker=WORKER, task_id="A_2", status=DONE) self.assertEqual({"A_1", "A_2"}, set(self.sch.task_list(DONE, "").keys())) def _start_simple_batch(self, use_max=False, mark_running=True, resources=None): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": "1"}, batchable=True, resources=resources) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": "2"}, batchable=True, resources=resources) response = self.sch.get_work(worker=WORKER) if mark_running: batch_id = response["batch_id"] task_id, params = ("A_2", {"a": "2"}) if use_max else ("A_1_2", {"a": "1,2"}) self.sch.add_task(worker=WORKER, task_id=task_id, task_family="A", params=params, batch_id=batch_id, status="RUNNING") return batch_id, task_id, params def test_set_batch_runner_retry(self): batch_id, task_id, params = self._start_simple_batch() self.sch.add_task(worker=WORKER, task_id=task_id, task_family="A", params=params, batch_id=batch_id, status="RUNNING") self.assertEqual({task_id}, set(self.sch.task_list("RUNNING", "").keys())) self.assertEqual({"A_1", "A_2"}, set(self.sch.task_list(BATCH_RUNNING, "").keys())) def test_set_batch_runner_multiple_retries(self): batch_id, task_id, params = self._start_simple_batch() for _ in range(3): 
self.sch.add_task(worker=WORKER, task_id=task_id, task_family="A", params=params, batch_id=batch_id, status="RUNNING") self.assertEqual({task_id}, set(self.sch.task_list("RUNNING", "").keys())) self.assertEqual({"A_1", "A_2"}, set(self.sch.task_list(BATCH_RUNNING, "").keys())) def test_batch_fail(self): self._start_simple_batch() self.sch.add_task(worker=WORKER, task_id="A_1_2", status=FAILED, expl="bad failure") task_ids = {"A_1", "A_2"} self.assertEqual(task_ids, set(self.sch.task_list(FAILED, "").keys())) for task_id in task_ids: expl = self.sch.fetch_error(task_id)["error"] self.assertEqual("bad failure", expl) def test_batch_fail_max(self): self._start_simple_batch(use_max=True) self.sch.add_task(worker=WORKER, task_id="A_2", status=FAILED, expl="bad max failure") task_ids = {"A_1", "A_2"} self.assertEqual(task_ids, set(self.sch.task_list(FAILED, "").keys())) for task_id in task_ids: response = self.sch.fetch_error(task_id) self.assertEqual("bad max failure", response["error"]) def test_batch_fail_from_dead_worker(self): self.setTime(1) self._start_simple_batch() self.setTime(601) self.sch.prune() self.assertEqual({"A_1", "A_2"}, set(self.sch.task_list(FAILED, "").keys())) def test_batch_fail_max_from_dead_worker(self): self.setTime(1) self._start_simple_batch(use_max=True) self.setTime(601) self.sch.prune() self.assertEqual({"A_1", "A_2"}, set(self.sch.task_list(FAILED, "").keys())) def test_batch_fail_from_dead_worker_without_running(self): self.setTime(1) self._start_simple_batch(mark_running=False) self.setTime(601) self.sch.prune() self.assertEqual({"A_1", "A_2"}, set(self.sch.task_list(FAILED, "").keys())) def test_batch_update_status(self): self._start_simple_batch() self.sch.set_task_status_message("A_1_2", "test message") for task_id in ("A_1", "A_2", "A_1_2"): self.assertEqual("test message", self.sch.get_task_status_message(task_id)["statusMessage"]) def test_batch_update_progress(self): self._start_simple_batch() 
self.sch.set_task_progress_percentage("A_1_2", 30) for task_id in ("A_1", "A_2", "A_1_2"): self.assertEqual(30, self.sch.get_task_progress_percentage(task_id)["progressPercentage"]) def test_batch_decrease_resources(self): self.sch.update_resources(x=3) self._start_simple_batch(resources={"x": 3}) self.sch.decrease_running_task_resources("A_1_2", {"x": 1}) for task_id in ("A_1", "A_2", "A_1_2"): self.assertEqual(2, self.sch.get_running_task_resources(task_id)["resources"]["x"]) def test_batch_tracking_url(self): self._start_simple_batch() self.sch.add_task(worker=WORKER, task_id="A_1_2", tracking_url="http://test.tracking.url/") tasks = self.sch.task_list("", "") for task_id in ("A_1", "A_2", "A_1_2"): self.assertEqual("http://test.tracking.url/", tasks[task_id]["tracking_url"]) def test_finish_batch(self): self._start_simple_batch() self.sch.add_task(worker=WORKER, task_id="A_1_2", status=DONE) self.assertEqual({"A_1", "A_2", "A_1_2"}, set(self.sch.task_list(DONE, "").keys())) def test_reschedule_max_batch(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": "1"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": "2"}, batchable=True) response = self.sch.get_work(worker=WORKER) batch_id = response["batch_id"] self.sch.add_task(worker=WORKER, task_id="A_2", task_family="A", params={"a": "2"}, batch_id=batch_id, status="RUNNING") self.sch.add_task(worker=WORKER, task_id="A_2", status=DONE) self.sch.add_task(worker=WORKER, task_id="A_2", task_family="A", params={"a": "2"}, batchable=True) self.assertEqual({"A_2"}, set(self.sch.task_list(PENDING, "").keys())) self.assertEqual({"A_1"}, set(self.sch.task_list(DONE, "").keys())) def test_resend_batch_on_get_work_retry(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": "1"}, 
batchable=True) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": "2"}, batchable=True) response = self.sch.get_work(worker=WORKER) response2 = self.sch.get_work(worker=WORKER, current_tasks=()) self.assertEqual(response["task_id"], response2["task_id"]) self.assertEqual(response["task_family"], response2.get("task_family")) self.assertEqual(response["task_params"], response2.get("task_params")) def test_resend_batch_runner_on_get_work_retry(self): self._start_simple_batch() get_work = self.sch.get_work(worker=WORKER, current_tasks=()) self.assertEqual("A_1_2", get_work["task_id"]) def test_resend_max_batch_runner_on_get_work_retry(self): self._start_simple_batch(use_max=True) get_work = self.sch.get_work(worker=WORKER, current_tasks=()) self.assertEqual("A_2", get_work["task_id"]) def test_do_not_resend_batch_runner_on_get_work(self): self._start_simple_batch() get_work = self.sch.get_work(worker=WORKER, current_tasks=("A_1_2",)) self.assertIsNone(get_work["task_id"]) def test_do_not_resend_max_batch_runner_on_get_work(self): self._start_simple_batch(use_max=True) get_work = self.sch.get_work(worker=WORKER, current_tasks=("A_2",)) self.assertIsNone(get_work["task_id"]) def test_rescheduled_batch_running_tasks_stay_batch_running_before_runner(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": "1"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": "2"}, batchable=True) self.sch.get_work(worker=WORKER) self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": "1"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": "2"}, batchable=True) self.assertEqual({"A_1", "A_2"}, set(self.sch.task_list(BATCH_RUNNING, "").keys())) def test_rescheduled_batch_running_tasks_stay_batch_running_after_runner(self): self._start_simple_batch() 
self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": "1"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": "2"}, batchable=True) self.assertEqual({"A_1", "A_2"}, set(self.sch.task_list(BATCH_RUNNING, "").keys())) def test_disabled_batch_running_tasks_stay_batch_running_before_runner(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": "1"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": "2"}, batchable=True) self.sch.get_work(worker=WORKER) self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": "1"}, batchable=True, status=DISABLED) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": "2"}, batchable=True, status=DISABLED) self.assertEqual({"A_1", "A_2"}, set(self.sch.task_list(BATCH_RUNNING, "").keys())) def test_get_work_returns_batch_task_id_list(self): self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"]) self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": "1"}, batchable=True) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": "2"}, batchable=True) response = self.sch.get_work(worker=WORKER) self.assertEqual({"A_1", "A_2"}, set(response["batch_task_ids"])) def test_disabled_batch_running_tasks_stay_batch_running_after_runner(self): self._start_simple_batch() self.sch.add_task(worker=WORKER, task_id="A_1", family="A", params={"a": "1"}, batchable=True, status=DISABLED) self.sch.add_task(worker=WORKER, task_id="A_2", family="A", params={"a": "2"}, batchable=True, status=DISABLED) self.assertEqual({"A_1", "A_2"}, set(self.sch.task_list(BATCH_RUNNING, "").keys())) def test_do_not_overwrite_tracking_url_while_running(self): self.sch.add_task(task_id="A", worker="X", status="RUNNING", tracking_url="trackme") self.assertEqual("trackme", 
self.sch.task_list("RUNNING", "")["A"]["tracking_url"]) # not wiped out by another working scheduling as pending self.sch.add_task(task_id="A", worker="Y", status="PENDING") self.assertEqual("trackme", self.sch.task_list("RUNNING", "")["A"]["tracking_url"]) def test_do_update_tracking_url_while_running(self): self.sch.add_task(task_id="A", worker="X", status="RUNNING", tracking_url="trackme") self.assertEqual("trackme", self.sch.task_list("RUNNING", "")["A"]["tracking_url"]) self.sch.add_task(task_id="A", worker="X", status="RUNNING", tracking_url="stage_2") self.assertEqual("stage_2", self.sch.task_list("RUNNING", "")["A"]["tracking_url"]) def test_keep_tracking_url_on_done_and_fail(self): for status in ("DONE", "FAILED"): self.sch.add_task(task_id="A", worker="X", status="RUNNING", tracking_url="trackme") self.assertEqual("trackme", self.sch.task_list("RUNNING", "")["A"]["tracking_url"]) self.sch.add_task(task_id="A", worker="X", status=status) self.assertEqual("trackme", self.sch.task_list(status, "")["A"]["tracking_url"]) def test_drop_tracking_url_when_rescheduled_while_not_running(self): for status in ("DONE", "FAILED", "PENDING"): self.sch.add_task(task_id="A", worker="X", status=status, tracking_url="trackme") self.assertEqual("trackme", self.sch.task_list(status, "")["A"]["tracking_url"]) self.sch.add_task(task_id="A", worker="Y", status="PENDING") self.assertIsNone(self.sch.task_list("PENDING", "")["A"]["tracking_url"]) def test_reset_tracking_url_on_new_run(self): self.sch.add_task(task_id="A", worker="X", status="PENDING", tracking_url="trackme") self.assertEqual("trackme", self.sch.task_list("PENDING", "")["A"]["tracking_url"]) self.sch.add_task(task_id="A", worker="Y", status="RUNNING") self.assertIsNone(self.sch.task_list("RUNNING", "")["A"]["tracking_url"]) def test_remove_dep(self): # X schedules A -> B, A is broken # Y schedules C -> B: this should remove A as a dep of B self.sch.add_task(task_id="A", worker="X", runnable=False) 
self.sch.add_task(task_id="B", deps=("A",), worker="X") # X can't build anything self.assertEqual(self.sch.get_work(worker="X")["task_id"], None) self.sch.add_task(task_id="B", deps=("C",), worker="Y") # should reset dependencies for A self.sch.add_task(task_id="C", worker="Y", status=DONE) self.assertEqual(self.sch.get_work(worker="Y")["task_id"], "B") def test_start_time(self): self.setTime(100) self.sch.add_task(worker=WORKER, task_id="A") self.setTime(200) self.sch.add_task(worker=WORKER, task_id="A") self.sch.add_task(worker=WORKER, task_id="A", status=DONE) self.assertEqual(100, self.sch.task_list(DONE, "")["A"]["start_time"]) def test_last_updated_does_not_change_with_same_status_update(self): for t, status in ((100, PENDING), (300, DONE), (500, DISABLED)): self.setTime(t) self.sch.add_task(worker=WORKER, task_id="A", status=status) self.assertEqual(t, self.sch.task_list(status, "")["A"]["last_updated"]) self.setTime(t + 100) self.sch.add_task(worker=WORKER, task_id="A", status=status) self.assertEqual(t, self.sch.task_list(status, "")["A"]["last_updated"]) def test_last_updated_shows_running_start(self): self.setTime(100) self.sch.add_task(worker=WORKER, task_id="A", status=PENDING) self.assertEqual(100, self.sch.task_list(PENDING, "")["A"]["last_updated"]) self.setTime(200) self.assertEqual("A", self.sch.get_work(worker=WORKER)["task_id"]) self.assertEqual(200, self.sch.task_list("RUNNING", "")["A"]["last_updated"]) self.setTime(300) self.sch.add_task(worker=WORKER, task_id="A", status=PENDING) self.assertEqual(200, self.sch.task_list("RUNNING", "")["A"]["last_updated"]) def test_last_updated_with_failure_and_recovery(self): self.setTime(100) self.sch.add_task(worker=WORKER, task_id="A") self.assertEqual("A", self.sch.get_work(worker=WORKER)["task_id"]) self.setTime(200) self.sch.add_task(worker=WORKER, task_id="A", status=FAILED) self.assertEqual(200, self.sch.task_list(FAILED, "")["A"]["last_updated"]) self.setTime(1000) self.sch.prune() 
self.assertEqual(1000, self.sch.task_list(PENDING, "")["A"]["last_updated"]) def test_timeout(self): # A bug that was earlier present when restarting the same flow self.setTime(0) self.sch.add_task(task_id="A", worker="X") self.assertEqual(self.sch.get_work(worker="X")["task_id"], "A") self.setTime(10000) self.sch.add_task(task_id="A", worker="Y") # Will timeout X but not schedule A for removal for i in range(2000): self.setTime(10000 + i) self.sch.ping(worker="Y") self.sch.add_task(task_id="A", status=DONE, worker="Y") # This used to raise an exception since A was removed def test_disallowed_state_changes(self): # Test that we can not schedule an already running task t = "A" self.sch.add_task(task_id=t, worker="X") self.assertEqual(self.sch.get_work(worker="X")["task_id"], t) self.sch.add_task(task_id=t, worker="Y") self.assertEqual(self.sch.get_work(worker="Y")["task_id"], None) def test_two_worker_info(self): # Make sure the scheduler returns info that some other worker is running task A self.sch.add_task(worker="X", task_id="A") self.sch.add_task(worker="Y", task_id="A") self.assertEqual(self.sch.get_work(worker="X")["task_id"], "A") r = self.sch.get_work(worker="Y") self.assertEqual(r["task_id"], None) # Worker Y is pending on A to be done s = r["running_tasks"][0] self.assertEqual(s["task_id"], "A") self.assertEqual(s["worker"], "X") def test_assistant_get_work(self): self.sch.add_task(worker="X", task_id="A") self.sch.add_worker("Y", []) self.assertEqual(self.sch.get_work(worker="Y", assistant=True)["task_id"], "A") # check that the scheduler recognizes tasks as running running_tasks = self.sch.task_list("RUNNING", "") self.assertEqual(len(running_tasks), 1) self.assertEqual(list(running_tasks.keys()), ["A"]) self.assertEqual(running_tasks["A"]["worker_running"], "Y") def test_assistant_get_work_external_task(self): self.sch.add_task(worker="X", task_id="A", runnable=False) self.assertTrue(self.sch.get_work(worker="Y", assistant=True)["task_id"] is None) def 
test_task_fails_when_assistant_dies(self): self.setTime(0) self.sch.add_task(worker="X", task_id="A") self.sch.add_worker("Y", []) self.assertEqual(self.sch.get_work(worker="Y", assistant=True)["task_id"], "A") self.assertEqual(list(self.sch.task_list("RUNNING", "").keys()), ["A"]) # Y dies for 50 seconds, X stays alive self.setTime(50) self.sch.ping(worker="X") self.assertEqual(list(self.sch.task_list("FAILED", "").keys()), ["A"]) def test_prune_with_live_assistant(self): self.setTime(0) self.sch.add_task(worker="X", task_id="A") self.sch.get_work(worker="Y", assistant=True) self.sch.add_task(worker="Y", task_id="A", status=DONE, assistant=True) # worker X stops communicating, A should be marked for removal self.setTime(600) self.sch.ping(worker="Y") self.sch.prune() # A will now be pruned self.setTime(2000) self.sch.prune() self.assertFalse(list(self.sch.task_list("", ""))) def test_re_enable_failed_task_assistant(self): self.setTime(0) self.sch.add_worker("X", [("assistant", True)]) self.sch.add_task(worker="X", task_id="A", status=FAILED, assistant=True) # should be failed now self.assertEqual(FAILED, self.sch.task_list("", "")["A"]["status"]) # resets to PENDING after 100 seconds self.setTime(101) self.sch.ping(worker="X") # worker still alive self.assertEqual("PENDING", self.sch.task_list("", "")["A"]["status"]) def test_assistant_doesnt_keep_alive_task(self): self.setTime(0) self.sch.add_task(worker="X", task_id="A") self.assertEqual("A", self.sch.get_work(worker="X")["task_id"]) self.sch.add_worker("Y", {"assistant": True}) remove_delay = self.get_scheduler_config()["remove_delay"] + 1.0 self.setTime(remove_delay) self.sch.ping(worker="Y") self.sch.prune() self.assertEqual(["A"], list(self.sch.task_list(status="FAILED", upstream_status="").keys())) self.assertEqual(["A"], list(self.sch.task_list(status="", upstream_status="").keys())) self.setTime(2 * remove_delay) self.sch.ping(worker="Y") self.sch.prune() self.assertEqual([], 
list(self.sch.task_list(status="", upstream_status="").keys())) def test_assistant_request_runnable_task(self): """ Test that an assistant gets a task despite it havent registered for it """ self.setTime(0) self.sch.add_task(worker="X", task_id="A", runnable=True) self.setTime(600) self.sch.prune() self.assertEqual("A", self.sch.get_work(worker="Y", assistant=True)["task_id"]) def test_assistant_request_external_task(self): self.sch.add_task(worker="X", task_id="A", runnable=False) self.assertIsNone(self.sch.get_work(worker="Y", assistant=True)["task_id"]) def _test_prune_done_tasks(self, expected=None): self.setTime(0) self.sch.add_task(worker=WORKER, task_id="A", status=DONE) self.sch.add_task(worker=WORKER, task_id="B", deps=["A"], status=DONE) self.sch.add_task(worker=WORKER, task_id="C", deps=["B"]) self.setTime(600) self.sch.ping(worker="MAYBE_ASSITANT") self.sch.prune() self.setTime(2000) self.sch.ping(worker="MAYBE_ASSITANT") self.sch.prune() self.assertEqual(set(expected), set(self.sch.task_list("", "").keys())) def test_prune_done_tasks_not_assistant(self, expected=None): # Here, MAYBE_ASSISTANT isnt an assistant self._test_prune_done_tasks(expected=[]) def test_keep_tasks_for_assistant(self): self.sch.get_work(worker="MAYBE_ASSITANT", assistant=True) # tell the scheduler this is an assistant self._test_prune_done_tasks([]) def test_keep_scheduler_disabled_tasks_for_assistant(self): self.sch.get_work(worker="MAYBE_ASSITANT", assistant=True) # tell the scheduler this is an assistant # create a scheduler disabled task and a worker disabled task for i in range(10): self.sch.add_task(worker=WORKER, task_id="D", status=FAILED) self.sch.add_task(worker=WORKER, task_id="E", status=DISABLED) # scheduler prunes the worker disabled task self.assertEqual({"D", "E"}, set(self.sch.task_list(DISABLED, ""))) self._test_prune_done_tasks([]) def test_keep_failed_tasks_for_assistant(self): self.sch.get_work(worker="MAYBE_ASSITANT", assistant=True) # tell the scheduler this 
is an assistant self.sch.add_task(worker=WORKER, task_id="D", status=FAILED, deps=["A"]) self._test_prune_done_tasks([]) def test_count_pending(self): for num_tasks in range(1, 20): self.sch.add_task(worker=WORKER, task_id=str(num_tasks), status=PENDING) expected = { "n_pending_tasks": num_tasks, "n_unique_pending": num_tasks, "n_pending_last_scheduled": num_tasks, "running_tasks": [], "worker_state": "active", } self.assertEqual(expected, self.sch.count_pending(WORKER)) def test_count_pending_include_failures(self): for num_tasks in range(1, 20): # must be scheduled as pending before failed to ensure WORKER is in the task's workers self.sch.add_task(worker=WORKER, task_id=str(num_tasks), status=PENDING) self.sch.add_task(worker=WORKER, task_id=str(num_tasks), status=FAILED) expected = { "n_pending_tasks": num_tasks, "n_unique_pending": num_tasks, "n_pending_last_scheduled": num_tasks, "running_tasks": [], "worker_state": "active", } self.assertEqual(expected, self.sch.count_pending(WORKER)) def test_count_pending_do_not_include_done_or_disabled(self): for num_tasks in range(1, 20, 2): self.sch.add_task(worker=WORKER, task_id=str(num_tasks), status=PENDING) self.sch.add_task(worker=WORKER, task_id=str(num_tasks + 1), status=PENDING) self.sch.add_task(worker=WORKER, task_id=str(num_tasks), status=DONE) self.sch.add_task(worker=WORKER, task_id=str(num_tasks + 1), status=DISABLED) expected = { "n_pending_tasks": 0, "n_unique_pending": 0, "n_pending_last_scheduled": 0, "running_tasks": [], "worker_state": "active", } self.assertEqual(expected, self.sch.count_pending(WORKER)) def test_count_pending_on_disabled_worker(self): self.sch.add_task(worker=WORKER, task_id="A") self.sch.add_task(worker="other", task_id="B") # needed to trigger right get_tasks code path self.assertEqual(1, self.sch.count_pending(WORKER)["n_pending_tasks"]) self.sch.disable_worker(WORKER) self.assertEqual(0, self.sch.count_pending(WORKER)["n_pending_tasks"]) def 
test_count_pending_do_not_count_upstream_disabled(self): self.sch.add_task(worker=WORKER, task_id="A", status=PENDING) self.sch.add_task(worker=WORKER, task_id="B", status=DISABLED) self.sch.add_task(worker=WORKER, task_id="C", status=PENDING, deps=["A", "B"]) expected = { "n_pending_tasks": 1, "n_unique_pending": 1, "n_pending_last_scheduled": 1, "running_tasks": [], "worker_state": "active", } self.assertEqual(expected, self.sch.count_pending(WORKER)) def test_count_pending_count_upstream_failed(self): self.sch.add_task(worker=WORKER, task_id="A", status=PENDING) self.sch.add_task(worker=WORKER, task_id="A", status=FAILED) self.sch.add_task(worker=WORKER, task_id="B", status=PENDING, deps=["A"]) expected = { "n_pending_tasks": 2, "n_unique_pending": 2, "n_pending_last_scheduled": 2, "running_tasks": [], "worker_state": "active", } self.assertEqual(expected, self.sch.count_pending(WORKER)) def test_count_pending_missing_worker(self): self.sch.add_task(worker=WORKER, task_id="A", status=PENDING) expected = { "n_pending_tasks": 0, "n_unique_pending": 0, "n_pending_last_scheduled": 0, "running_tasks": [], "worker_state": "active", } self.assertEqual(expected, self.sch.count_pending("other_worker")) def test_count_pending_uniques(self): self.sch.add_task(worker=WORKER, task_id="A", status=PENDING) self.sch.add_task(worker=WORKER, task_id="B", status=PENDING) self.sch.add_task(worker=WORKER, task_id="C", status=PENDING) self.sch.add_task(worker="other_worker", task_id="A", status=PENDING) expected = { "n_pending_tasks": 3, "n_unique_pending": 2, "n_pending_last_scheduled": 2, "running_tasks": [], "worker_state": "active", } self.assertEqual(expected, self.sch.count_pending(WORKER)) def test_count_pending_last_scheduled(self): self.sch.add_task(worker=WORKER, task_id="A", status=PENDING) self.sch.add_task(worker=WORKER, task_id="B", status=PENDING) self.sch.add_task(worker=WORKER, task_id="C", status=PENDING) self.sch.add_task(worker="other_worker", task_id="A", 
status=PENDING) self.sch.add_task(worker="other_worker", task_id="B", status=PENDING) self.sch.add_task(worker="other_worker", task_id="C", status=PENDING) expected = { "n_pending_tasks": 3, "n_unique_pending": 0, "n_pending_last_scheduled": 0, "running_tasks": [], "worker_state": "active", } self.assertEqual(expected, self.sch.count_pending(WORKER)) expected_other_worker = { "n_pending_tasks": 3, "n_unique_pending": 0, "n_pending_last_scheduled": 3, "running_tasks": [], "worker_state": "active", } self.assertEqual(expected_other_worker, self.sch.count_pending("other_worker")) def test_count_pending_disabled_worker(self): self.sch.add_task(worker=WORKER, task_id="A", status=PENDING) expected_active_state = { "n_pending_tasks": 1, "n_unique_pending": 1, "n_pending_last_scheduled": 1, "running_tasks": [], "worker_state": "active", } self.assertEqual(expected_active_state, self.sch.count_pending(worker=WORKER)) expected_disabled_state = { "n_pending_tasks": 0, "n_unique_pending": 0, "n_pending_last_scheduled": 0, "running_tasks": [], "worker_state": "disabled", } self.sch.disable_worker(worker=WORKER) self.assertEqual(expected_disabled_state, self.sch.count_pending(worker=WORKER)) def test_count_pending_running_tasks(self): self.sch.add_task(worker=WORKER, task_id="A", status=PENDING) self.assertEqual("A", self.sch.get_work(worker=WORKER)["task_id"]) expected_active_state = { "n_pending_tasks": 0, "n_unique_pending": 0, "n_pending_last_scheduled": 0, "running_tasks": [{"task_id": "A", "worker": "myworker"}], "worker_state": "active", } self.assertEqual(expected_active_state, self.sch.count_pending(worker=WORKER)) def test_scheduler_resources_none_allow_one(self): self.sch.add_task(worker="X", task_id="A", resources={"R1": 1}) self.assertEqual(self.sch.get_work(worker="X")["task_id"], "A") def test_scheduler_resources_none_disallow_two(self): self.sch.add_task(worker="X", task_id="A", resources={"R1": 2}) self.assertFalse(self.sch.get_work(worker="X")["task_id"], "A") 
def test_scheduler_with_insufficient_resources(self):
    self.sch.add_task(worker="X", task_id="A", resources={"R1": 3})
    self.sch.update_resources(R1=2)
    self.assertFalse(self.sch.get_work(worker="X")["task_id"])

def test_scheduler_with_sufficient_resources(self):
    self.sch.add_task(worker="X", task_id="A", resources={"R1": 3})
    self.sch.update_resources(R1=3)
    self.assertEqual(self.sch.get_work(worker="X")["task_id"], "A")

def test_scheduler_with_resources_used(self):
    self.sch.add_task(worker="X", task_id="A", resources={"R1": 1})
    self.assertEqual(self.sch.get_work(worker="X")["task_id"], "A")

    self.sch.add_task(worker="Y", task_id="B", resources={"R1": 1})
    self.sch.update_resources(R1=1)
    self.assertFalse(self.sch.get_work(worker="Y")["task_id"])

def test_scheduler_overprovisioned_on_other_resource(self):
    self.sch.add_task(worker="X", task_id="A", resources={"R1": 2})
    self.sch.update_resources(R1=2)
    self.assertEqual(self.sch.get_work(worker="X")["task_id"], "A")

    self.sch.add_task(worker="Y", task_id="B", resources={"R2": 2})
    self.sch.update_resources(R1=1, R2=2)
    self.assertEqual(self.sch.get_work(worker="Y")["task_id"], "B")

def test_scheduler_with_priority_and_competing_resources(self):
    self.sch.add_task(worker="X", task_id="A")
    self.assertEqual(self.sch.get_work(worker="X")["task_id"], "A")

    self.sch.add_task(worker="X", task_id="B", resources={"R": 1}, priority=10)
    self.sch.add_task(worker="Y", task_id="C", resources={"R": 1}, priority=1)
    self.sch.update_resources(R=1)
    self.assertFalse(self.sch.get_work(worker="Y")["task_id"])

    self.sch.add_task(worker="Y", task_id="D", priority=0)
    self.assertEqual(self.sch.get_work(worker="Y")["task_id"], "D")

def test_do_not_lock_resources_when_not_ready(self):
    """Test to make sure that resources won't go unused waiting on workers"""
    self.sch.add_task(worker="X", task_id="A", priority=10)
    self.sch.add_task(worker="X", task_id="B", resources={"R": 1}, priority=5)
    self.sch.add_task(worker="Y", task_id="C", resources={"R": 1}, priority=1)
    self.sch.update_resources(R=1)

    self.sch.add_worker("X", [("workers", 1)])
    self.assertEqual("C", self.sch.get_work(worker="Y")["task_id"])

def test_lock_resources_when_one_of_multiple_workers_is_ready(self):
    self.sch.get_work(worker="X")  # indicate to the scheduler that X is active
    self.sch.add_task(worker="X", task_id="A", priority=10)
    self.sch.add_task(worker="X", task_id="B", resources={"R": 1}, priority=5)
    self.sch.add_task(worker="Y", task_id="C", resources={"R": 1}, priority=1)
    self.sch.update_resources(R=1)

    self.sch.add_worker("X", [("workers", 2)])
    self.sch.add_worker("Y", [])
    self.assertFalse(self.sch.get_work(worker="Y")["task_id"])

def test_do_not_lock_resources_while_running_higher_priority(self):
    """Resources are not reserved for a worker that is busy running a higher-priority task,
    so they don't go unused while waiting on that worker."""
    self.sch.add_task(worker="X", task_id="A", priority=10)
    self.sch.add_task(worker="X", task_id="B", resources={"R": 1}, priority=5)
    self.sch.add_task(worker="Y", task_id="C", resources={"R": 1}, priority=1)
    self.sch.update_resources(R=1)

    self.sch.add_worker("X", [("workers", 1)])
    self.assertEqual("A", self.sch.get_work(worker="X")["task_id"])
    self.assertEqual("C", self.sch.get_work(worker="Y")["task_id"])

def test_lock_resources_while_running_lower_priority(self):
    """Resources stay reserved for a higher-priority task (B) even while its worker is busy
    running a lower-priority task (A), so a lower-priority task (C) cannot take them."""
    self.sch.add_task(worker="X", task_id="A", priority=4)
    self.assertEqual("A", self.sch.get_work(worker="X")["task_id"])
    self.sch.add_task(worker="X", task_id="B", resources={"R": 1}, priority=5)
    self.sch.add_task(worker="Y", task_id="C", resources={"R": 1}, priority=1)
    self.sch.update_resources(R=1)

    self.sch.add_worker("X", [("workers", 1)])
    self.assertFalse(self.sch.get_work(worker="Y")["task_id"])

def test_lock_resources_for_second_worker(self):
    self.sch.get_work(worker="Y")  # indicate to the scheduler that Y is active
    self.sch.add_task(worker="X", task_id="A", resources={"R": 1})
    self.sch.add_task(worker="X", task_id="B", resources={"R": 1})
    self.sch.add_task(worker="Y", task_id="C", resources={"R": 1}, priority=10)
    self.sch.add_worker("X", {"workers": 2})
    self.sch.add_worker("Y", {"workers": 1})
    self.sch.update_resources(R=2)

    self.assertEqual("A", self.sch.get_work(worker="X")["task_id"])
    self.assertFalse(self.sch.get_work(worker="X")["task_id"])

def test_can_work_on_lower_priority_while_waiting_for_resources(self):
    self.sch.add_task(worker="X", task_id="A", resources={"R": 1}, priority=0)
    self.assertEqual("A", self.sch.get_work(worker="X")["task_id"])

    self.sch.add_task(worker="Y", task_id="B", resources={"R": 1}, priority=10)
    self.sch.add_task(worker="Y", task_id="C", priority=0)
    self.sch.update_resources(R=1)

    self.assertEqual("C", self.sch.get_work(worker="Y")["task_id"])

def validate_resource_count(self, name, count):
    """Assert that resource ``name`` has ``count`` total units in the scheduler.

    ``count=None`` asserts that the resource is not listed at all.
    """
    counts = {resource["name"]: resource["num_total"] for resource in self.sch.resource_list()}
    self.assertEqual(count, counts.get(name))

def test_update_new_resource(self):
    self.validate_resource_count("new_resource", None)  # new_resource is not in the scheduler
    self.sch.update_resource("new_resource", 1)
    self.validate_resource_count("new_resource", 1)

def test_update_existing_resource(self):
    self.sch.update_resource("new_resource", 1)
    self.sch.update_resource("new_resource", 2)
    self.validate_resource_count("new_resource", 2)

def test_disable_existing_resource(self):
    self.sch.update_resource("new_resource", 1)
    self.sch.update_resource("new_resource", 0)
    self.validate_resource_count("new_resource", 0)

def test_attempt_to_set_resource_to_negative_value(self):
    self.sch.update_resource("new_resource", 1)
    self.assertFalse(self.sch.update_resource("new_resource", -1))
    self.validate_resource_count("new_resource", 1)

def test_attempt_to_set_resource_to_non_integer(self):
    self.sch.update_resource("new_resource", 1)
    self.assertFalse(self.sch.update_resource("new_resource", 1.3))
    self.assertFalse(self.sch.update_resource("new_resource", "1"))
    self.assertFalse(self.sch.update_resource("new_resource", None))
    self.validate_resource_count("new_resource", 1)

def test_priority_update_with_pruning(self):
    self.setTime(0)
    self.sch.add_task(task_id="A", worker="X")

    self.setTime(50)  # after worker disconnects
    self.sch.prune()
    self.sch.add_task(task_id="B", deps=["A"], worker="X")

    self.setTime(2000)  # after remove for task A
    self.sch.prune()

    # Here task A that B depends on is missing
    self.sch.add_task(worker=WORKER, task_id="C", deps=["B"], priority=100)
    self.sch.add_task(worker=WORKER, task_id="B", deps=["A"])
    self.sch.add_task(worker=WORKER, task_id="A")
    self.sch.add_task(worker=WORKER, task_id="D", priority=10)

    self.check_task_order("ABCD")

def test_update_resources(self):
    self.sch.add_task(worker=WORKER, task_id="A", deps=["B"])
    self.sch.add_task(worker=WORKER, task_id="B", resources={"r": 2})
    self.sch.update_resources(r=1)

    # B requires too many resources, we can't schedule
    self.check_task_order([])

    self.sch.add_task(worker=WORKER, task_id="B", resources={"r": 1})

    # now we have enough resources
    self.check_task_order(["B", "A"])

def test_handle_multiple_resources(self):
    self.sch.add_task(worker=WORKER, task_id="A", resources={"r1": 1, "r2": 1})
    self.sch.add_task(worker=WORKER, task_id="B", resources={"r1": 1, "r2": 1})
    self.sch.add_task(worker=WORKER, task_id="C", resources={"r1": 1})
    self.sch.update_resources(r1=2, r2=1)

    self.assertEqual("A", self.sch.get_work(worker=WORKER)["task_id"])
    self.check_task_order("C")

def test_single_resource_lock(self):
    self.sch.add_task(worker="X", task_id="A", resources={"r": 1})
    self.assertEqual("A", self.sch.get_work(worker="X")["task_id"])

    self.sch.add_task(worker=WORKER, task_id="B", resources={"r": 2}, priority=10)
    self.sch.add_task(worker=WORKER, task_id="C", resources={"r": 1})
    self.sch.update_resources(r=2)

    # Should wait for 2 units of r to be available for B before scheduling C
    self.check_task_order([])

def test_no_lock_if_too_many_resources_required(self):
    self.sch.add_task(worker=WORKER, task_id="A", resources={"r": 2}, priority=10)
    self.sch.add_task(worker=WORKER, task_id="B", resources={"r": 1})
    self.sch.update_resources(r=1)
    self.check_task_order("B")

def test_multiple_resources_lock(self):
    self.sch.get_work(worker="X")  # indicate to the scheduler that X is active
    self.sch.add_task(worker="X", task_id="A", resources={"r1": 1, "r2": 1}, priority=10)
    self.sch.add_task(worker=WORKER, task_id="B", resources={"r2": 1})
    self.sch.add_task(worker=WORKER, task_id="C", resources={"r1": 1})
    self.sch.update_resources(r1=1, r2=1)

    # should preserve both resources for worker 'X'
    self.check_task_order([])

def test_multiple_resources_no_lock(self):
    self.sch.add_task(worker=WORKER, task_id="A", resources={"r1": 1}, priority=10)
    self.sch.add_task(worker=WORKER, task_id="B", resources={"r1": 1, "r2": 1}, priority=10)
    self.sch.add_task(worker=WORKER, task_id="C", resources={"r2": 1})
    self.sch.update_resources(r1=1, r2=2)

    self.assertEqual("A", self.sch.get_work(worker=WORKER)["task_id"])
    # C doesn't block B, so it can go first
    self.check_task_order("C")

def test_do_not_allow_stowaway_resources(self):
    self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"])
    self.sch.add_task(worker=WORKER, task_id="A1", resources={"r1": 1}, family="A", params={"a": "1"}, batchable=True, priority=1)
    self.sch.add_task(worker=WORKER, task_id="A2", resources={"r1": 2}, family="A", params={"a": "2"}, batchable=True)
    self.sch.add_task(worker=WORKER, task_id="A3", resources={"r2": 1}, family="A", params={"a": "3"}, batchable=True)
    self.sch.add_task(worker=WORKER, task_id="A4", resources={"r1": 1}, family="A", params={"a": "4"}, batchable=True)
    self.assertEqual({"A1", "A4"}, set(self.sch.get_work(worker=WORKER)["batch_task_ids"]))

def test_do_not_allow_same_resources(self):
    self.sch.add_task_batcher(worker=WORKER, task_family="A", batched_args=["a"])
    self.sch.add_task(worker=WORKER, task_id="A1", resources={"r1": 1}, family="A", params={"a": "1"}, batchable=True, priority=1)
    self.sch.add_task(worker=WORKER, task_id="A2", resources={"r1": 1}, family="A", params={"a": "2"}, batchable=True)
    self.sch.add_task(worker=WORKER, task_id="A3", resources={"r1": 1}, family="A", params={"a": "3"}, batchable=True)
    self.sch.add_task(worker=WORKER, task_id="A4", resources={"r1": 1}, family="A", params={"a": "4"}, batchable=True)
    self.assertEqual({"A1", "A2", "A3", "A4"}, set(self.sch.get_work(worker=WORKER)["batch_task_ids"]))

def test_change_resources_on_running_task(self):
    self.sch.add_task(worker=WORKER, task_id="A1", resources={"a": 1}, priority=10)
    self.sch.add_task(worker=WORKER, task_id="A2", resources={"a": 1}, priority=1)

    self.assertEqual("A1", self.sch.get_work(worker=WORKER)["task_id"])
    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])

    # switch the resource of the running task
    self.sch.add_task(worker="other", task_id="A1", resources={"b": 1}, priority=1)

    # the running task should be using the resource it had when it started running
    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])

def test_interleave_resource_change_and_get_work(self):
    for i in range(100):
        self.sch.add_task(worker=WORKER, task_id="A{}".format(i), resources={"a": 1}, priority=100 - i)

    for i in range(100):
        self.sch.get_work(worker=WORKER)
        self.sch.add_task(worker="other", task_id="A{}".format(i), resources={"b": 1}, priority=100 - i)

    # we should only see 1 task per resource rather than all 100 tasks running
    self.assertEqual(2, len(self.sch.task_list(RUNNING, "")))

def test_assistant_has_different_resources_than_scheduled_max_task_id(self):
    self.sch.add_task_batcher(worker="assistant", task_family="A", batched_args=["a"], max_batch_size=2)
    self.sch.add_task(worker=WORKER, task_id="A1", resources={"a": 1}, family="A", params={"a": "1"}, batchable=True, priority=1)
    self.sch.add_task(worker=WORKER, task_id="A2", resources={"a": 1}, family="A", params={"a": "2"}, batchable=True, priority=2)
    self.sch.add_task(worker=WORKER, task_id="A3", resources={"a": 1}, family="A", params={"a": "3"}, batchable=True, priority=3)

    result = self.sch.get_work(worker="assistant", assistant=True)
    self.assertEqual({"A3", "A2"}, set(result["batch_task_ids"]))
    self.sch.add_task(worker="assistant", task_id="A3", status=RUNNING, batch_id=result["batch_id"], resources={"b": 1})

    # the assistant reported different resources, but only after the batch was already
    # running, so the batch keeps the resources it was scheduled with and A1 cannot start
    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])

def test_assistant_has_different_resources_than_scheduled_new_task_id(self):
    self.sch.add_task_batcher(worker="assistant", task_family="A", batched_args=["a"], max_batch_size=2)
    self.sch.add_task(worker=WORKER, task_id="A1", resources={"a": 1}, family="A", params={"a": "1"}, batchable=True, priority=1)
    self.sch.add_task(worker=WORKER, task_id="A2", resources={"a": 1}, family="A", params={"a": "2"}, batchable=True, priority=2)
    self.sch.add_task(worker=WORKER, task_id="A3", resources={"a": 1}, family="A", params={"a": "3"}, batchable=True, priority=3)

    result = self.sch.get_work(worker="assistant", assistant=True)
    self.assertEqual({"A3", "A2"}, set(result["batch_task_ids"]))
    self.sch.add_task(worker="assistant", task_id="A_2_3", status=RUNNING, batch_id=result["batch_id"], resources={"b": 1})

    # the assistant reported different resources, but only after the batch was already
    # running, so the batch keeps the resources it was scheduled with and A1 cannot start
    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])

def test_assistant_has_different_resources_than_scheduled_max_task_id_during_scheduling(self):
    self.sch.add_task_batcher(worker="assistant", task_family="A", batched_args=["a"], max_batch_size=2)
    self.sch.add_task(worker=WORKER, task_id="A1", resources={"a": 1}, family="A", params={"a": "1"}, batchable=True, priority=1)
    self.sch.add_task(worker=WORKER, task_id="A2", resources={"a": 1}, family="A", params={"a": "2"}, batchable=True, priority=2)
    self.sch.add_task(worker=WORKER, task_id="A3", resources={"a": 1}, family="A", params={"a": "3"}, batchable=True, priority=3)

    result = self.sch.get_work(worker="assistant", assistant=True)
    self.assertEqual({"A3", "A2"}, set(result["batch_task_ids"]))
    self.sch.add_task(worker=WORKER, task_id="A2", resources={"b": 1}, family="A", params={"a": "2"}, batchable=True, priority=2)
    self.sch.add_task(worker=WORKER, task_id="A3", resources={"b": 1}, family="A", params={"a": "3"}, batchable=True, priority=3)
    self.sch.add_task(worker="assistant", task_id="A3", status=RUNNING, batch_id=result["batch_id"], resources={"b": 1})

    # the tasks' resources changed, but only after they were batch running, so the
    # batch still holds the resources it was scheduled with and A1 cannot start
    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])

def test_assistant_has_different_resources_than_scheduled_new_task_id_during_scheduling(self):
    self.sch.add_task_batcher(worker="assistant", task_family="A", batched_args=["a"], max_batch_size=2)
    self.sch.add_task(worker=WORKER, task_id="A1", resources={"a": 1}, family="A", params={"a": "1"}, batchable=True, priority=1)
    self.sch.add_task(worker=WORKER, task_id="A2", resources={"a": 1}, family="A", params={"a": "2"}, batchable=True, priority=2)
    self.sch.add_task(worker=WORKER, task_id="A3", resources={"a": 1}, family="A", params={"a": "3"}, batchable=True, priority=3)

    result = self.sch.get_work(worker="assistant", assistant=True)
    self.assertEqual({"A3", "A2"}, set(result["batch_task_ids"]))
    self.sch.add_task(worker=WORKER, task_id="A2", resources={"b": 1}, family="A", params={"a": "2"}, batchable=True, priority=2)
    self.sch.add_task(worker=WORKER, task_id="A3", resources={"b": 1}, family="A", params={"a": "3"}, batchable=True, priority=3)
    self.sch.add_task(worker="assistant", task_id="A_2_3", status=RUNNING, batch_id=result["batch_id"], resources={"b": 1})

    # the tasks' resources changed, but only after they were batch running, so the
    # batch still holds the resources it was scheduled with and A1 cannot start
    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])

def test_allow_resource_use_while_scheduling(self):
    self.sch.update_resources(r1=1)
    self.sch.add_task(worker="SCHEDULING", task_id="A", resources={"r1": 1}, priority=10)
    self.sch.add_task(worker=WORKER, task_id="B", resources={"r1": 1}, priority=1)
    self.assertEqual("B", self.sch.get_work(worker=WORKER)["task_id"])

def test_stop_locking_resource_for_uninterested_worker(self):
    self.setTime(0)
    self.sch.update_resources(r1=1)
    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])
    self.sch.add_task(worker=WORKER, task_id="A", resources={"r1": 1}, priority=10)
    self.sch.add_task(worker="LOW_PRIO", task_id="B", resources={"r1": 1}, priority=1)
    self.assertIsNone(self.sch.get_work(worker="LOW_PRIO")["task_id"])

    self.setTime(120)
    self.assertEqual("B", self.sch.get_work(worker="LOW_PRIO")["task_id"])

def check_task_order(self, order):
    """Assert that WORKER is handed tasks in exactly ``order``.

    Each task is marked DONE as soon as it is handed out. Once ``order`` is
    exhausted, WORKER must receive no further work.
    """
    for expected_id in order:
        self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], expected_id)
        self.sch.add_task(worker=WORKER, task_id=expected_id, status=DONE)
    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])

def test_priorities(self):
    self.sch.add_task(worker=WORKER, task_id="A", priority=10)
    self.sch.add_task(worker=WORKER, task_id="B", priority=5)
    self.sch.add_task(worker=WORKER, task_id="C", priority=15)
    self.sch.add_task(worker=WORKER, task_id="D", priority=9)
    self.check_task_order(["C", "A", "D", "B"])

def test_priorities_default_and_negative(self):
    self.sch.add_task(worker=WORKER, task_id="A", priority=10)
    self.sch.add_task(worker=WORKER, task_id="B")
    self.sch.add_task(worker=WORKER, task_id="C", priority=15)
    self.sch.add_task(worker=WORKER, task_id="D", priority=-20)
    self.sch.add_task(worker=WORKER, task_id="E", priority=1)
    self.check_task_order(["C", "A", "E", "B", "D"])

def test_priorities_and_dependencies(self):
    self.sch.add_task(worker=WORKER, task_id="A", deps=["Z"], priority=10)
    self.sch.add_task(worker=WORKER, task_id="B", priority=5)
    self.sch.add_task(worker=WORKER, task_id="C", deps=["Z"], priority=3)
    self.sch.add_task(worker=WORKER, task_id="D", priority=2)
    self.sch.add_task(worker=WORKER, task_id="Z", priority=1)
    self.check_task_order(["Z", "A", "B", "C", "D"])

def test_priority_update_dependency_after_scheduling(self):
    self.sch.add_task(worker=WORKER, task_id="A", priority=1)
    self.sch.add_task(worker=WORKER, task_id="B", priority=5, deps=["A"])
    self.sch.add_task(worker=WORKER, task_id="C", priority=10, deps=["B"])
    self.sch.add_task(worker=WORKER, task_id="D", priority=6)
    self.check_task_order(["A", "B", "C", "D"])

def test_disable(self):
    self.sch.add_task(worker=WORKER, task_id="A")
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)

    # should be disabled at this point
    self.assertEqual(len(self.sch.task_list("DISABLED", "")), 1)
    self.assertEqual(len(self.sch.task_list("FAILED", "")), 0)
    self.sch.add_task(worker=WORKER, task_id="A")
    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])

def test_disable_and_reenable(self):
    self.sch.add_task(worker=WORKER, task_id="A")
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)

    # should be disabled at this point
    self.assertEqual(len(self.sch.task_list("DISABLED", "")), 1)
    self.assertEqual(len(self.sch.task_list("FAILED", "")), 0)

    self.sch.re_enable_task("A")

    # should be enabled at this point
    self.assertEqual(len(self.sch.task_list("DISABLED", "")), 0)
    self.assertEqual(len(self.sch.task_list("FAILED", "")), 1)
    self.sch.add_task(worker=WORKER, task_id="A")
    self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A")

def test_disable_and_reenable_and_disable_again(self):
    self.sch.add_task(worker=WORKER, task_id="A")
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)

    # should be disabled at this point
    self.assertEqual(len(self.sch.task_list("DISABLED", "")), 1)
    self.assertEqual(len(self.sch.task_list("FAILED", "")), 0)

    self.sch.re_enable_task("A")

    # should be enabled at this point
    self.assertEqual(len(self.sch.task_list("DISABLED", "")), 0)
    self.assertEqual(len(self.sch.task_list("FAILED", "")), 1)
    self.sch.add_task(worker=WORKER, task_id="A")
    self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A")

    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)

    # should still be enabled
    self.assertEqual(len(self.sch.task_list("DISABLED", "")), 0)
    self.assertEqual(len(self.sch.task_list("FAILED", "")), 1)
    self.sch.add_task(worker=WORKER, task_id="A")
    self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A")

    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)

    # should be disabled now
    self.assertEqual(len(self.sch.task_list("DISABLED", "")), 1)
    self.assertEqual(len(self.sch.task_list("FAILED", "")), 0)
    self.sch.add_task(worker=WORKER, task_id="A")
    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])

def test_disable_and_done(self):
    self.sch.add_task(worker=WORKER, task_id="A")
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)

    # should be disabled at this point
    self.assertEqual(len(self.sch.task_list("DISABLED", "")), 1)
    self.assertEqual(len(self.sch.task_list("FAILED", "")), 0)

    self.sch.add_task(worker=WORKER, task_id="A", status=DONE)

    # should be enabled at this point
    self.assertEqual(len(self.sch.task_list("DISABLED", "")), 0)
    self.assertEqual(len(self.sch.task_list("DONE", "")), 1)
    self.sch.add_task(worker=WORKER, task_id="A")
    self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A")

def test_automatic_re_enable(self):
    self.sch = Scheduler(retry_count=2, disable_persist=100)
    self.setTime(0)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)

    # should be disabled now
    self.assertEqual(DISABLED, self.sch.task_list("", "")["A"]["status"])

    # re-enables after 100 seconds
    self.setTime(101)
    self.assertEqual(FAILED, self.sch.task_list("", "")["A"]["status"])

def test_automatic_re_enable_with_one_failure_allowed(self):
    self.sch = Scheduler(retry_count=1, disable_persist=100)
    self.setTime(0)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)

    # should be disabled now
    self.assertEqual(DISABLED, self.sch.task_list("", "")["A"]["status"])

    # re-enables after 100 seconds
    self.setTime(101)
    self.assertEqual(FAILED, self.sch.task_list("", "")["A"]["status"])

def test_no_automatic_re_enable_after_manual_disable(self):
    self.sch = Scheduler(disable_persist=100)
    self.setTime(0)
    self.sch.add_task(worker=WORKER, task_id="A", status=DISABLED)

    # should be disabled now
    self.assertEqual(DISABLED, self.sch.task_list("", "")["A"]["status"])

    # should not re-enable after 100 seconds
    self.setTime(101)
    self.assertEqual(DISABLED, self.sch.task_list("", "")["A"]["status"])

def test_no_automatic_re_enable_after_auto_then_manual_disable(self):
    self.sch = Scheduler(retry_count=2, disable_persist=100)
    self.setTime(0)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)

    # should be disabled now
    self.assertEqual(DISABLED, self.sch.task_list("", "")["A"]["status"])

    # should remain disabled once set
    self.sch.add_task(worker=WORKER, task_id="A", status=DISABLED)
    self.assertEqual(DISABLED, self.sch.task_list("", "")["A"]["status"])

    # should not re-enable after 100 seconds
    self.setTime(101)
    self.assertEqual(DISABLED, self.sch.task_list("", "")["A"]["status"])

def test_disable_by_worker(self):
    self.sch.add_task(worker=WORKER, task_id="A", status=DISABLED)
    self.assertEqual(len(self.sch.task_list("DISABLED", "")), 1)

    self.sch.add_task(worker=WORKER, task_id="A")

    # should be enabled at this point
    self.assertEqual(len(self.sch.task_list("DISABLED", "")), 0)
    self.sch.add_task(worker=WORKER, task_id="A")
    self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A")

def test_disable_worker(self):
    self.sch.add_task(worker=WORKER, task_id="A")
    self.sch.disable_worker(worker=WORKER)
    work = self.sch.get_work(worker=WORKER)
    self.assertEqual(0, work["n_unique_pending"])
    self.assertEqual(0, work["n_pending_tasks"])
    self.assertIsNone(work["task_id"])

def test_pause_work(self):
    self.sch.add_task(worker=WORKER, task_id="A")
    self.sch.pause()
    self.assertEqual(
        {
            "n_pending_last_scheduled": 1,
            "n_unique_pending": 1,
            "n_pending_tasks": 1,
            "running_tasks": [],
            "task_id": None,
            "worker_state": "active",
        },
        self.sch.get_work(worker=WORKER),
    )

    self.sch.unpause()
    self.assertEqual("A", self.sch.get_work(worker=WORKER)["task_id"])

def test_is_paused(self):
    self.assertFalse(self.sch.is_paused()["paused"])
    self.sch.pause()
    self.assertTrue(self.sch.is_paused()["paused"])
    self.sch.unpause()
    self.assertFalse(self.sch.is_paused()["paused"])

def test_disable_worker_leaves_jobs_running(self):
    self.sch.add_task(worker=WORKER, task_id="A")
    self.sch.get_work(worker=WORKER)

    self.sch.disable_worker(worker=WORKER)
    self.assertEqual(["A"], list(self.sch.task_list("RUNNING", "").keys()))
    self.assertEqual(["A"], list(self.sch.worker_list()[0]["running"].keys()))

def test_disable_worker_cannot_pick_up_failed_jobs(self):
    self.setTime(0)

    self.sch.add_task(worker=WORKER, task_id="A")
    self.sch.get_work(worker=WORKER)
    self.sch.disable_worker(worker=WORKER)
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)

    # increase time and prune to make the job pending again
    self.setTime(1000)
    self.sch.ping(worker=WORKER)
    self.sch.prune()

    # we won't try the job again
    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])

    # not even if other stuff is pending, changing the pending tasks code path
    self.sch.add_task(worker="other_worker", task_id="B")
    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])

def test_disable_worker_cannot_continue_scheduling(self):
    self.sch.disable_worker(worker=WORKER)
    self.sch.add_task(worker=WORKER, task_id="A")
    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])

def test_disable_worker_cannot_add_tasks(self):
    """
    Verify that a disabled worker cannot add tasks
    """
    self.sch.disable_worker(worker=WORKER)
    self.sch.add_task(worker=WORKER, task_id="A")
    self.assertIsNone(self.sch.get_work(worker="assistant", assistant=True)["task_id"])
    self.sch.add_task(worker="third_enabled_worker", task_id="A")
    self.assertIsNotNone(self.sch.get_work(worker="assistant", assistant=True)["task_id"])

def _test_disable_worker_helper(self, new_status, new_deps):
    """Check how a disabled worker may still report on a task it was already running.

    Starts task A on WORKER, disables WORKER, then (after registering each of
    ``new_deps`` as PENDING) reports A with ``new_status``. Asserts that A leaves
    RUNNING and lands in ``new_status``, that WORKER receives no further work,
    and that no task keeps any worker attached.
    """
    self.sch.add_task(worker=WORKER, task_id="A")
    self.assertEqual("A", self.sch.get_work(worker=WORKER)["task_id"])

    self.sch.disable_worker(worker=WORKER)
    self.assertEqual(["A"], list(self.sch.task_list("RUNNING", "").keys()))

    for dep in new_deps:
        self.sch.add_task(worker=WORKER, task_id=dep, status=PENDING)
    self.sch.add_task(worker=WORKER, task_id="A", status=new_status, new_deps=new_deps)
    self.assertFalse(self.sch.task_list("RUNNING", "").keys())
    self.assertEqual(["A"], list(self.sch.task_list(new_status, "").keys()))

    self.assertIsNone(self.sch.get_work(worker=WORKER)["task_id"])
    for task in self.sch.task_list("", "").values():
        self.assertFalse(task["workers"])

def test_disable_worker_can_finish_task(self):
    self._test_disable_worker_helper(new_status=DONE, new_deps=[])

def test_disable_worker_can_fail_task(self):
    self._test_disable_worker_helper(new_status=FAILED, new_deps=[])

def test_disable_worker_stays_disabled_on_new_deps(self):
    self._test_disable_worker_helper(new_status=PENDING, new_deps=["B", "C"])

def test_disable_worker_assistant_gets_no_task(self):
    self.setTime(0)
    self.sch.add_task(worker=WORKER, task_id="A")
    self.sch.add_worker("assistant", [("assistant", True)])
    self.sch.ping(worker="assistant")
    self.sch.disable_worker("assistant")
    self.assertIsNone(self.sch.get_work(worker="assistant", assistant=True)["task_id"])
    self.assertIsNotNone(self.sch.get_work(worker=WORKER)["task_id"])

def test_prune_worker(self):
    self.setTime(1)
    self.sch.add_worker(worker=WORKER, info={})
    self.setTime(10000)
    self.sch.prune()
    self.setTime(20000)
    self.sch.prune()
    self.assertFalse(self.sch.worker_list())

def test_task_list_beyond_limit(self):
    sch = Scheduler(max_shown_tasks=3)
    for c in "ABCD":
        sch.add_task(worker=WORKER, task_id=c)
    self.assertEqual(set("ABCD"), set(sch.task_list("PENDING", "", False).keys()))
    self.assertEqual({"num_tasks": 4}, sch.task_list("PENDING", ""))

def test_task_list_within_limit(self):
    sch = Scheduler(max_shown_tasks=4)
    for c in "ABCD":
        sch.add_task(worker=WORKER, task_id=c)
    self.assertEqual(set("ABCD"), set(sch.task_list("PENDING", "").keys()))

def test_task_lists_some_beyond_limit(self):
    sch = Scheduler(max_shown_tasks=3)
    for c in "ABCD":
        sch.add_task(worker=WORKER, task_id=c, status=DONE)
    for c in "EFG":
        sch.add_task(worker=WORKER, task_id=c)
    self.assertEqual(set("EFG"), set(sch.task_list("PENDING", "").keys()))
    self.assertEqual({"num_tasks": 4}, sch.task_list("DONE", ""))

def test_dynamic_shown_tasks_in_task_list(self):
    sch = Scheduler(max_shown_tasks=3)
    for task_id in "ABCD":
        sch.add_task(worker=WORKER, task_id=task_id, status=DONE)
    for task_id in "EFG":
        sch.add_task(worker=WORKER, task_id=task_id)

    self.assertEqual(set("EFG"), set(sch.task_list("PENDING", "").keys()))
    self.assertEqual({"num_tasks": 3}, sch.task_list("PENDING", "", max_shown_tasks=2))

    self.assertEqual({"num_tasks": 4}, sch.task_list("DONE", ""))
    self.assertEqual(set("ABCD"), set(sch.task_list("DONE", "", max_shown_tasks=4).keys()))

def add_task(self, family, **params):
    """Register a task of ``family`` with ``params`` for WORKER and return its task id.

    The id is an opaque hash, so it carries no family or parameter text that
    the search tests could accidentally match on.
    """
    task_id = str(hash((family, str(params))))  # use an unhelpful task id
    self.sch.add_task(worker=WORKER, family=family, params=params, task_id=task_id)
    return task_id

def search_pending(self, term, expected_keys):
    """Assert that searching PENDING tasks for ``term`` yields exactly ``expected_keys``."""
    actual_keys = set(self.sch.task_list("PENDING", "", search=term).keys())
    self.assertEqual(expected_keys, actual_keys)

def test_task_list_filter_by_search_family_name(self):
    task1 = self.add_task("MySpecialTask")
    task2 = self.add_task("OtherSpecialTask")

    self.search_pending("Special", {task1, task2})
    self.search_pending("Task", {task1, task2})
    self.search_pending("My", {task1})
    self.search_pending("Other", {task2})

def test_task_list_filter_by_search_long_family_name(self):
    task = self.add_task("TaskClassWithAVeryLongNameAndDistinctEndingUUDDLRLRAB")
    self.search_pending("UUDDLRLRAB", {task})

def test_task_list_filter_by_param_name(self):
    task1 = self.add_task("ClassA", day="2016-02-01")
    task2 = self.add_task("ClassB", hour="2016-02-01T12")

    self.search_pending("day", {task1})
    self.search_pending("hour", {task2})

def test_task_list_filter_by_long_param_name(self):
    task = self.add_task("ClassA", a_very_long_param_name_ending_with_uuddlrlrab="2016-02-01")
    self.search_pending("uuddlrlrab", {task})

def test_task_list_filter_by_param_value(self):
    task1 = self.add_task("ClassA", day="2016-02-01")
    task2 = self.add_task("ClassB", hour="2016-02-01T12")

    self.search_pending("2016-02-01", {task1, task2})
    self.search_pending("T12", {task2})

def test_task_list_filter_by_long_param_value(self):
    task = self.add_task("ClassA", param="a_very_long_param_value_ending_with_uuddlrlrab")
    self.search_pending("uuddlrlrab", {task})

def test_task_list_filter_by_param_name_value_pair(self):
    task = self.add_task("ClassA", param="value")
    self.search_pending("param=value", {task})

def test_task_list_does_not_filter_by_task_id(self):
    task = self.add_task("Class")
    self.search_pending(task, set())

def test_task_list_filter_by_multiple_search_terms(self):
    expected = self.add_task("ClassA", day="2016-02-01", num="5")
    self.add_task("ClassA", day="2016-03-01", num="5")
    self.add_task("ClassB", day="2016-02-01", num="5")
    self.add_task("ClassA", day="2016-02-01", val="5")

    self.search_pending("ClassA 2016-02-01 num", {expected})
    # ensure that the task search is case insensitive
    self.search_pending("classa 2016-02-01 num", {expected})

def test_upstream_beyond_limit(self):
    sch = Scheduler(max_shown_tasks=3)
    for i in range(4):
        sch.add_task(worker=WORKER, family="Test", params={"p": str(i)}, task_id="Test_%i" % i)
    self.assertEqual({"num_tasks": -1}, sch.task_list("PENDING", "FAILED"))
    self.assertEqual({"num_tasks": 4}, sch.task_list("PENDING", ""))

def test_do_not_prune_on_beyond_limit_check(self):
    sch = Scheduler(max_shown_tasks=3)
    sch.prune = mock.Mock()
    for i in range(4):
        sch.add_task(worker=WORKER, family="Test", params={"p": str(i)}, task_id="Test_%i" % i)
    self.assertEqual({"num_tasks": 4}, sch.task_list("PENDING", ""))
    sch.prune.assert_not_called()

def test_search_results_beyond_limit(self):
    sch = Scheduler(max_shown_tasks=3)
    for i in range(4):
        sch.add_task(worker=WORKER, family="Test", params={"p": str(i)}, task_id="Test_%i" % i)
    self.assertEqual({"num_tasks": 4}, sch.task_list("PENDING", "", search="Test"))
    self.assertEqual(["Test_0"], list(sch.task_list("PENDING", "", search="0").keys()))

def test_priority_update_dependency_chain(self):
    self.sch.add_task(worker=WORKER, task_id="A", priority=10, deps=["B"])
    self.sch.add_task(worker=WORKER, task_id="B", priority=5, deps=["C"])
    self.sch.add_task(worker=WORKER, task_id="C", priority=1)
    self.sch.add_task(worker=WORKER, task_id="D", priority=6)
    self.check_task_order(["C", "B", "A", "D"])

def test_priority_no_decrease_with_multiple_updates(self):
    self.sch.add_task(worker=WORKER, task_id="A", priority=1)
    self.sch.add_task(worker=WORKER, task_id="B", priority=10, deps=["A"])
    self.sch.add_task(worker=WORKER, task_id="C", priority=5, deps=["A"])
    self.sch.add_task(worker=WORKER, task_id="D", priority=6)
    self.check_task_order(["A", "B", "D", "C"])

def test_unique_tasks(self):
    self.sch.add_task(worker=WORKER, task_id="A")
    self.sch.add_task(worker=WORKER, task_id="B")
    self.sch.add_task(worker=WORKER, task_id="C")
    self.sch.add_task(worker=WORKER + "_2", task_id="B")

    response = self.sch.get_work(worker=WORKER)
    self.assertEqual(3, response["n_pending_tasks"])
    self.assertEqual(2, response["n_unique_pending"])

def test_pending_downstream_disable(self):
    self.sch.add_task(worker=WORKER, task_id="A", status=DISABLED)
    self.sch.add_task(worker=WORKER, task_id="B", deps=("A",))
    self.sch.add_task(worker=WORKER, task_id="C", deps=("B",))

    response = self.sch.get_work(worker=WORKER)
    self.assertIsNone(response["task_id"])
    self.assertEqual(0, response["n_pending_tasks"])
    self.assertEqual(0, response["n_unique_pending"])

def test_pending_downstream_failure(self):
    self.sch.add_task(worker=WORKER, task_id="A", status=FAILED)
    self.sch.add_task(worker=WORKER, task_id="B", deps=("A",))
    self.sch.add_task(worker=WORKER, task_id="C", deps=("B",))

    response = self.sch.get_work(worker=WORKER)
    self.assertIsNone(response["task_id"])
    self.assertEqual(2, response["n_pending_tasks"])
    self.assertEqual(2, response["n_unique_pending"])

def test_task_list_no_deps(self):
    self.sch.add_task(worker=WORKER, task_id="B", deps=("A",))
    self.sch.add_task(worker=WORKER, task_id="A")
    task_list = self.sch.task_list("PENDING", "")
    self.assertNotIn("deps", task_list["A"])

def test_task_first_failure_time(self):
    self.sch.add_task(worker=WORKER, task_id="A")
    test_task = self.sch._state.get_task("A")
    self.assertIsNone(test_task.first_failure_time)

    time_before_failure = time.time()
    test_task.add_failure()
    time_after_failure = time.time()

    self.assertLessEqual(time_before_failure, test_task.first_failure_time)
    self.assertGreaterEqual(time_after_failure, test_task.first_failure_time)

def test_task_first_failure_time_remains_constant(self):
    self.sch.add_task(worker=WORKER, task_id="A")
    test_task = self.sch._state.get_task("A")
    self.assertIsNone(test_task.first_failure_time)

    test_task.add_failure()
    first_failure_time = test_task.first_failure_time

    test_task.add_failure()
    self.assertEqual(first_failure_time, test_task.first_failure_time)

def test_task_has_excessive_failures(self):
    self.sch.add_task(worker=WORKER, task_id="A")
    test_task = self.sch._state.get_task("A")
    self.assertIsNone(test_task.first_failure_time)

    self.assertFalse(test_task.has_excessive_failures())

    test_task.add_failure()
    self.assertFalse(test_task.has_excessive_failures())

    # backdate the first failure by two hours
    fake_failure_time = test_task.first_failure_time - 2 * 60 * 60

    test_task.first_failure_time = fake_failure_time
    self.assertTrue(test_task.has_excessive_failures())

def test_quadratic_behavior(self):
    """Test that get_work does not take time linear in the number of tasks.

    This is of course impossible to test exactly; instead, making reasonable
    assumptions about hardware, this test should finish in a timely manner.
    """
    # For 10000 it takes almost 1 second on my laptop.  Prior to these
    # changes it was being slow already at NUM_TASKS=300
    NUM_TASKS = 10000
    for i in range(NUM_TASKS):
        self.sch.add_task(worker=str(i), task_id=str(i), resources={})

    for i in range(NUM_TASKS):
        self.assertEqual(self.sch.get_work(worker=str(i))["task_id"], str(i))
        self.sch.add_task(worker=str(i), task_id=str(i), status=DONE)

def test_get_work_speed(self):
    """Test that get_work is fast for few workers and many DONEs.

    In #986, @daveFNbuck reported that he got a slowdown.
    """
    # This took almost 4 minutes without optimization.
    # Now it takes 10 seconds on my machine.
NUM_PENDING = 1000 NUM_DONE = 200000 assert NUM_DONE >= NUM_PENDING for i in range(NUM_PENDING): self.sch.add_task(worker=WORKER, task_id=str(i), resources={}) for i in range(NUM_PENDING, NUM_DONE): self.sch.add_task(worker=WORKER, task_id=str(i), status=DONE) for i in range(NUM_PENDING): res = int(self.sch.get_work(worker=WORKER)["task_id"]) self.assertTrue(0 <= res < NUM_PENDING) self.sch.add_task(worker=WORKER, task_id=str(res), status=DONE) def test_assistants_dont_nurture_finished_statuses(self): """ Test how assistants affect longevity of tasks Assistants should not affect longevity expect for the tasks that it is running, par the one it's actually running. """ self.sch = Scheduler(retry_delay=100000000000) # Never pendify failed tasks self.setTime(1) self.sch.add_worker("assistant", [("assistant", True)]) self.sch.ping(worker="assistant") self.sch.add_task(worker="uploader", task_id="running", status=PENDING) self.assertEqual(self.sch.get_work(worker="assistant", assistant=True)["task_id"], "running") self.setTime(2) self.sch.add_task(worker="uploader", task_id="done", status=DONE) self.sch.add_task(worker="uploader", task_id="disabled", status=DISABLED) self.sch.add_task(worker="uploader", task_id="pending", status=PENDING) self.sch.add_task(worker="uploader", task_id="failed", status=FAILED) self.sch.add_task(worker="uploader", task_id="unknown", status=UNKNOWN) self.setTime(100000) self.sch.ping(worker="assistant") self.sch.prune() self.setTime(200000) self.sch.ping(worker="assistant") self.sch.prune() nurtured_statuses = [RUNNING] not_nurtured_statuses = [DONE, UNKNOWN, DISABLED, PENDING, FAILED] for status in nurtured_statuses: self.assertEqual(set([status.lower()]), set(self.sch.task_list(status, ""))) for status in not_nurtured_statuses: self.assertEqual(set([]), set(self.sch.task_list(status, ""))) self.assertEqual(1, len(self.sch.task_list(None, ""))) # None == All statuses def test_no_crash_on_only_disable_hard_timeout(self): """ Scheduler 
shouldn't crash with only disable_hard_timeout There was some failure happening when disable_hard_timeout was set but disable_failures was not. """ self.sch = Scheduler(retry_delay=5, disable_hard_timeout=100) self.setTime(1) self.sch.add_worker(WORKER, []) self.sch.ping(worker=WORKER) self.setTime(2) self.sch.add_task(worker=WORKER, task_id="A") self.sch.add_task(worker=WORKER, task_id="B", deps=["A"]) self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A") self.sch.add_task(worker=WORKER, task_id="A", status=FAILED) self.setTime(10) self.sch.prune() self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A") def test_assistant_running_task_dont_disappear(self): """ Tasks run by an assistant shouldn't be pruned """ self.setTime(1) self.sch.add_worker(WORKER, []) self.sch.ping(worker=WORKER) self.setTime(2) self.sch.add_task(worker=WORKER, task_id="A") self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A") self.sch.add_task(worker=WORKER, task_id="B") self.sch.add_worker("assistant", [("assistant", True)]) self.sch.ping(worker="assistant") self.assertEqual(self.sch.get_work(worker="assistant", assistant=True)["task_id"], "B") self.setTime(100000) # Here, lets say WORKER disconnects (doesnt ping) self.sch.ping(worker="assistant") self.sch.prune() self.setTime(200000) self.sch.ping(worker="assistant") self.sch.prune() self.assertEqual({"B"}, set(self.sch.task_list(RUNNING, ""))) self.assertEqual({"B"}, set(self.sch.task_list("", ""))) @mock.patch("luigi.scheduler.BatchNotifier") def test_batch_failure_emails(self, BatchNotifier): scheduler = Scheduler(batch_emails=True) scheduler.add_task(worker=WORKER, status=FAILED, task_id="T(a=5, b=6)", family="T", params={"a": "5", "b": "6"}, expl='"bad thing"') BatchNotifier().add_failure.assert_called_once_with( "T(a=5, b=6)", "T", {"a": "5", "b": "6"}, "bad thing", None, ) BatchNotifier().add_disable.assert_not_called() @mock.patch("luigi.scheduler.BatchNotifier") def 
test_send_batch_email_on_dump(self, BatchNotifier): scheduler = Scheduler(batch_emails=True) BatchNotifier().send_email.assert_not_called() scheduler.dump() BatchNotifier().send_email.assert_called_once_with() @mock.patch("luigi.scheduler.BatchNotifier") def test_do_not_send_batch_email_on_dump_without_batch_enabled(self, BatchNotifier): scheduler = Scheduler(batch_emails=False) scheduler.dump() BatchNotifier().send_email.assert_not_called() @mock.patch("luigi.scheduler.BatchNotifier") def test_handle_bad_expl_in_failure_emails(self, BatchNotifier): scheduler = Scheduler(batch_emails=True) scheduler.add_task(worker=WORKER, status=FAILED, task_id="T(a=5, b=6)", family="T", params={"a": "5", "b": "6"}, expl="bad thing") BatchNotifier().add_failure.assert_called_once_with( "T(a=5, b=6)", "T", {"a": "5", "b": "6"}, "bad thing", None, ) BatchNotifier().add_disable.assert_not_called() @mock.patch("luigi.scheduler.BatchNotifier") def test_scheduling_failure(self, BatchNotifier): scheduler = Scheduler(batch_emails=True) scheduler.announce_scheduling_failure(worker=WORKER, task_name="T(a=1, b=2)", family="T", params={"a": "1", "b": "2"}, expl="error", owners=("owner",)) BatchNotifier().add_scheduling_fail.assert_called_once_with("T(a=1, b=2)", "T", {"a": "1", "b": "2"}, "error", ("owner",)) @mock.patch("luigi.scheduler.BatchNotifier") def test_scheduling_failure_without_batcher(self, BatchNotifier): scheduler = Scheduler(batch_emails=False) scheduler.announce_scheduling_failure(worker=WORKER, task_name="T(a=1, b=2)", family="T", params={"a": "1", "b": "2"}, expl="error", owners=("owner",)) BatchNotifier().add_scheduling_fail.assert_not_called() @mock.patch("luigi.scheduler.BatchNotifier") def test_batch_failure_emails_with_task_batcher(self, BatchNotifier): scheduler = Scheduler(batch_emails=True) scheduler.add_task_batcher(worker=WORKER, task_family="T", batched_args=["a"]) scheduler.add_task(worker=WORKER, status=FAILED, task_id="T(a=5, b=6)", family="T", params={"a": 
"5", "b": "6"}, expl='"bad thing"') BatchNotifier().add_failure.assert_called_once_with( "T(a=5, b=6)", "T", {"b": "6"}, "bad thing", None, ) BatchNotifier().add_disable.assert_not_called() @mock.patch("luigi.scheduler.BatchNotifier") def test_scheduling_failure_with_task_batcher(self, BatchNotifier): scheduler = Scheduler(batch_emails=True) scheduler.add_task_batcher(worker=WORKER, task_family="T", batched_args=["a"]) scheduler.announce_scheduling_failure(worker=WORKER, task_name="T(a=1, b=2)", family="T", params={"a": "1", "b": "2"}, expl="error", owners=("owner",)) BatchNotifier().add_scheduling_fail.assert_called_once_with("T(a=1, b=2)", "T", {"b": "2"}, "error", ("owner",)) @mock.patch("luigi.scheduler.BatchNotifier") def test_batch_failure_email_with_owner(self, BatchNotifier): scheduler = Scheduler(batch_emails=True) scheduler.add_task( worker=WORKER, status=FAILED, task_id="T(a=5, b=6)", family="T", params={"a": "5", "b": "6"}, expl='"bad thing"', owners=["a@test.com", "b@test.com"], ) BatchNotifier().add_failure.assert_called_once_with( "T(a=5, b=6)", "T", {"a": "5", "b": "6"}, "bad thing", ["a@test.com", "b@test.com"], ) BatchNotifier().add_disable.assert_not_called() @mock.patch("luigi.scheduler.notifications") @mock.patch("luigi.scheduler.BatchNotifier") def test_batch_disable_emails(self, BatchNotifier, notifications): scheduler = Scheduler(batch_emails=True, retry_count=1) scheduler.add_task(worker=WORKER, status=FAILED, task_id="T(a=5, b=6)", family="T", params={"a": "5", "b": "6"}, expl='"bad thing"') BatchNotifier().add_failure.assert_called_once_with( "T(a=5, b=6)", "T", {"a": "5", "b": "6"}, "bad thing", None, ) BatchNotifier().add_disable.assert_called_once_with( "T(a=5, b=6)", "T", {"a": "5", "b": "6"}, None, ) notifications.send_error_email.assert_not_called() @mock.patch("luigi.scheduler.notifications") @mock.patch("luigi.scheduler.BatchNotifier") def test_batch_disable_email_with_owner(self, BatchNotifier, notifications): scheduler = 
Scheduler(batch_emails=True, retry_count=1) scheduler.add_task( worker=WORKER, status=FAILED, task_id="T(a=5, b=6)", family="T", params={"a": "5", "b": "6"}, expl='"bad thing"', owners=["a@test.com"] ) BatchNotifier().add_failure.assert_called_once_with( "T(a=5, b=6)", "T", {"a": "5", "b": "6"}, "bad thing", ["a@test.com"], ) BatchNotifier().add_disable.assert_called_once_with( "T(a=5, b=6)", "T", {"a": "5", "b": "6"}, ["a@test.com"], ) notifications.send_error_email.assert_not_called() @mock.patch("luigi.scheduler.notifications") @mock.patch("luigi.scheduler.BatchNotifier") def test_batch_disable_emails_with_task_batcher(self, BatchNotifier, notifications): scheduler = Scheduler(batch_emails=True, retry_count=1) scheduler.add_task_batcher(worker=WORKER, task_family="T", batched_args=["a"]) scheduler.add_task(worker=WORKER, status=FAILED, task_id="T(a=5, b=6)", family="T", params={"a": "5", "b": "6"}, expl='"bad thing"') BatchNotifier().add_failure.assert_called_once_with( "T(a=5, b=6)", "T", {"b": "6"}, "bad thing", None, ) BatchNotifier().add_disable.assert_called_once_with( "T(a=5, b=6)", "T", {"b": "6"}, None, ) notifications.send_error_email.assert_not_called() @mock.patch("luigi.scheduler.notifications") def test_send_normal_disable_email(self, notifications): scheduler = Scheduler(batch_emails=False, retry_count=1) notifications.send_error_email.assert_not_called() scheduler.add_task(worker=WORKER, status=FAILED, task_id="T(a=5, b=6)", family="T", params={"a": "5", "b": "6"}, expl='"bad thing"') self.assertEqual(1, notifications.send_error_email.call_count) @mock.patch("luigi.scheduler.BatchNotifier") def test_no_batch_notifier_without_batch_emails(self, BatchNotifier): Scheduler(batch_emails=False) BatchNotifier.assert_not_called() @mock.patch("luigi.scheduler.BatchNotifier") def test_update_batcher_on_prune(self, BatchNotifier): scheduler = Scheduler(batch_emails=True) BatchNotifier().update.assert_not_called() scheduler.prune() 
BatchNotifier().update.assert_called_once_with() def test_forgive_failures(self): # Try to build A but fails, forgive failures and will retry before 100s self.setTime(0) self.sch.add_task(worker=WORKER, task_id="A") self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A") self.sch.add_task(worker=WORKER, task_id="A", status=FAILED) self.setTime(1) self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], None) self.setTime(2) self.sch.forgive_failures(task_id="A") self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A") def test_you_can_forgive_failures_twice(self): # Try to build A but fails, forgive failures two times and will retry before 100s self.setTime(0) self.sch.add_task(worker=WORKER, task_id="A") self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A") self.sch.add_task(worker=WORKER, task_id="A", status=FAILED) self.setTime(1) self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], None) self.setTime(2) self.sch.forgive_failures(task_id="A") self.sch.forgive_failures(task_id="A") self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A") def test_mark_running_as_done_works(self): # Adding a task, it runs, then force-commiting it sends it to DONE self.setTime(0) self.sch.add_task(worker=WORKER, task_id="A") self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A") self.setTime(1) self.assertEqual({"A"}, set(self.sch.task_list(RUNNING, "").keys())) self.sch.mark_as_done(task_id="A") self.assertEqual({"A"}, set(self.sch.task_list(DONE, "").keys())) def test_mark_failed_as_done_works(self): # Adding a task, saying it failed, then force-commiting it sends it to DONE self.setTime(0) self.sch.add_task(worker=WORKER, task_id="A") self.assertEqual(self.sch.get_work(worker=WORKER)["task_id"], "A") self.sch.add_task(worker=WORKER, task_id="A", status=FAILED) self.setTime(1) self.assertEqual(set(), set(self.sch.task_list(RUNNING, "").keys())) self.assertEqual({"A"}, set(self.sch.task_list(FAILED, "").keys())) 
self.sch.mark_as_done(task_id="A") self.assertEqual({"A"}, set(self.sch.task_list(DONE, "").keys())) @mock.patch("luigi.metrics.NoMetricsCollector") def test_collector_metrics_on_task_started(self, MetricsCollector): from luigi.metrics import MetricsCollectors s = Scheduler(metrics_collector=MetricsCollectors.none) s.add_task(worker=WORKER, task_id="A", status=PENDING) s.get_work(worker=WORKER) task = s._state.get_task("A") MetricsCollector().handle_task_started.assert_called_once_with(task) @mock.patch("luigi.metrics.NoMetricsCollector") def test_collector_metrics_on_task_disabled(self, MetricsCollector): from luigi.metrics import MetricsCollectors s = Scheduler(metrics_collector=MetricsCollectors.none, retry_count=0) s.add_task(worker=WORKER, task_id="A", status=FAILED) task = s._state.get_task("A") MetricsCollector().handle_task_disabled.assert_called_once_with(task, s._config) @mock.patch("luigi.metrics.NoMetricsCollector") def test_collector_metrics_on_task_failed(self, MetricsCollector): from luigi.metrics import MetricsCollectors s = Scheduler(metrics_collector=MetricsCollectors.none) s.add_task(worker=WORKER, task_id="A", status=FAILED) task = s._state.get_task("A") MetricsCollector().handle_task_failed.assert_called_once_with(task) @mock.patch("luigi.metrics.NoMetricsCollector") def test_collector_metrics_on_task_done(self, MetricsCollector): from luigi.metrics import MetricsCollectors s = Scheduler(metrics_collector=MetricsCollectors.none) s.add_task(worker=WORKER, task_id="A", status=DONE) task = s._state.get_task("A") MetricsCollector().handle_task_done.assert_called_once_with(task) ================================================ FILE: test/scheduler_message_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import os import tempfile import time from helpers import LuigiTestCase, RunOnceTask import luigi import luigi.scheduler import luigi.worker def fast_worker(scheduler, **kwargs): kwargs.setdefault("ping_interval", 0.5) kwargs.setdefault("force_multiprocessing", True) return luigi.worker.Worker(scheduler=scheduler, **kwargs) class WriteMessageToFile(luigi.Task): path = luigi.Parameter() accepts_messages = True def output(self): return luigi.LocalTarget(self.path) def run(self): msg = "" time.sleep(1) if not self.scheduler_messages.empty(): msg = self.scheduler_messages.get().content with self.output().open("w") as f: f.write(msg + "\n") class SchedulerMessageTest(LuigiTestCase): def test_scheduler_methods(self): sch = luigi.scheduler.Scheduler(send_messages=True) sch.add_task(task_id="foo-task", worker="foo-worker") res = sch.send_scheduler_message("foo-worker", "foo-task", "message content") message_id = res["message_id"] self.assertTrue(len(message_id) > 0) self.assertIn("-", message_id) sch.add_scheduler_message_response("foo-task", message_id, "message response") res = sch.get_scheduler_message_response("foo-task", message_id) response = res["response"] self.assertEqual(response, "message response") def test_receive_messsage(self): sch = luigi.scheduler.Scheduler(send_messages=True) with fast_worker(sch) as w: with tempfile.NamedTemporaryFile() as tmp: if os.path.exists(tmp.name): os.remove(tmp.name) task = WriteMessageToFile(path=tmp.name) w.add(task) sch.send_scheduler_message(w._id, task.task_id, "test") w.run() self.assertTrue(os.path.exists(tmp.name)) with 
open(tmp.name, "r") as f: self.assertEqual(str(f.read()).strip(), "test") def test_receive_messages_disabled(self): sch = luigi.scheduler.Scheduler(send_messages=True) with fast_worker(sch, force_multiprocessing=False) as w: class MyTask(RunOnceTask): def run(self): self.had_queue = self.scheduler_messages is not None super(MyTask, self).run() task = MyTask() w.add(task) sch.send_scheduler_message(w._id, task.task_id, "test") w.run() self.assertFalse(task.had_queue) def test_send_messages_disabled(self): sch = luigi.scheduler.Scheduler(send_messages=False) with fast_worker(sch) as w: with tempfile.NamedTemporaryFile() as tmp: if os.path.exists(tmp.name): os.remove(tmp.name) task = WriteMessageToFile(path=tmp.name) w.add(task) sch.send_scheduler_message(w._id, task.task_id, "test") w.run() self.assertTrue(os.path.exists(tmp.name)) with open(tmp.name, "r") as f: self.assertEqual(str(f.read()).strip(), "") ================================================ FILE: test/scheduler_parameter_visibilities_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import json import time import server_test from helpers import LuigiTestCase, RunOnceTask import luigi import luigi.scheduler import luigi.worker from luigi.parameter import ParameterVisibility class SchedulerParameterVisibilitiesTest(LuigiTestCase): def test_task_with_deps(self): s = luigi.scheduler.Scheduler(send_messages=True) with luigi.worker.Worker(scheduler=s) as w: class DynamicTask(RunOnceTask): dynamic_public = luigi.Parameter(default="dynamic_public") dynamic_hidden = luigi.Parameter(default="dynamic_hidden", visibility=ParameterVisibility.HIDDEN) dynamic_private = luigi.Parameter(default="dynamic_private", visibility=ParameterVisibility.PRIVATE) class RequiredTask(RunOnceTask): required_public = luigi.Parameter(default="required_param") required_hidden = luigi.Parameter(default="required_hidden", visibility=ParameterVisibility.HIDDEN) required_private = luigi.Parameter(default="required_private", visibility=ParameterVisibility.PRIVATE) class Task(RunOnceTask): a = luigi.Parameter(default="a") b = luigi.Parameter(default="b", visibility=ParameterVisibility.HIDDEN) c = luigi.Parameter(default="c", visibility=ParameterVisibility.PRIVATE) d = luigi.Parameter(default="d", visibility=ParameterVisibility.PUBLIC) def requires(self): return required_task def run(self): yield dynamic_task dynamic_task = DynamicTask() required_task = RequiredTask() task = Task() w.add(task) w.run() time.sleep(1) task_deps = s.dep_graph(task_id=task.task_id) required_task_deps = s.dep_graph(task_id=required_task.task_id) dynamic_task_deps = s.dep_graph(task_id=dynamic_task.task_id) self.assertEqual("Task(a=a, d=d)", task_deps[task.task_id]["display_name"]) self.assertEqual("RequiredTask(required_public=required_param)", required_task_deps[required_task.task_id]["display_name"]) self.assertEqual("DynamicTask(dynamic_public=dynamic_public)", dynamic_task_deps[dynamic_task.task_id]["display_name"]) self.assertEqual({"a": "a", "d": "d"}, task_deps[task.task_id]["params"]) 
self.assertEqual({"required_public": "required_param"}, required_task_deps[required_task.task_id]["params"]) self.assertEqual({"dynamic_public": "dynamic_public"}, dynamic_task_deps[dynamic_task.task_id]["params"]) def test_public_and_hidden_params(self): s = luigi.scheduler.Scheduler(send_messages=True) with luigi.worker.Worker(scheduler=s) as w: class Task(RunOnceTask): a = luigi.Parameter(default="a") b = luigi.Parameter(default="b", visibility=ParameterVisibility.HIDDEN) c = luigi.Parameter(default="c", visibility=ParameterVisibility.PRIVATE) d = luigi.Parameter(default="d", visibility=ParameterVisibility.PUBLIC) task = Task() w.add(task) w.run() time.sleep(1) t = s._state.get_task(task.task_id) self.assertEqual({"b": "b"}, t.hidden_params) self.assertEqual({"a": "a", "d": "d"}, t.public_params) self.assertEqual({"a": 0, "b": 1, "d": 0}, t.param_visibilities) class Task(RunOnceTask): a = luigi.Parameter(default="a") b = luigi.Parameter(default="b", visibility=ParameterVisibility.HIDDEN) c = luigi.Parameter(default="c", visibility=ParameterVisibility.PRIVATE) d = luigi.Parameter(default="d", visibility=ParameterVisibility.PUBLIC) class RemoteSchedulerParameterVisibilitiesTest(server_test.ServerTestBase): def test_public_params(self): task = Task() luigi.build(tasks=[task], workers=2, scheduler_port=self.get_http_port()) time.sleep(1) response = self.fetch("/api/graph") body = response.body decoded = body.decode("utf8").replace("'", '"') data = json.loads(decoded) self.assertEqual({"a": "a", "d": "d"}, data["response"][task.task_id]["params"]) ================================================ FILE: test/scheduler_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import os import pickle import shutil import tempfile import time from multiprocessing import Process from helpers import unittest, with_config import luigi.configuration import luigi.scheduler import luigi.server from luigi.target import FileAlreadyExists class SchedulerIoTest(unittest.TestCase): def test_pretty_id_unicode(self): scheduler = luigi.scheduler.Scheduler() scheduler.add_task(worker="A", task_id="1", params={"foo": "\u2192bar"}) [task] = list(scheduler._state.get_active_tasks()) task.pretty_id def test_load_old_state(self): tasks = {} active_workers = {"Worker1": 1e9, "Worker2": time.time()} with tempfile.NamedTemporaryFile(delete=True) as fn: with open(fn.name, "wb") as fobj: state = (tasks, active_workers) pickle.dump(state, fobj) state = luigi.scheduler.SimpleTaskState(state_path=fn.name) state.load() self.assertEqual(set(state.get_worker_ids()), {"Worker1", "Worker2"}) def test_load_broken_state(self): with tempfile.NamedTemporaryFile(delete=True) as fn: with open(fn.name, "w") as fobj: print("b0rk", file=fobj) state = luigi.scheduler.SimpleTaskState(state_path=fn.name) state.load() # bad if this crashes self.assertEqual(list(state.get_worker_ids()), []) @with_config({"scheduler": {"retry_count": "44", "worker_disconnect_delay": "55"}}) def test_scheduler_with_config(self): scheduler = luigi.scheduler.Scheduler() self.assertEqual(44, scheduler._config.retry_count) self.assertEqual(55, scheduler._config.worker_disconnect_delay) # Override scheduler = luigi.scheduler.Scheduler(retry_count=66, worker_disconnect_delay=77) self.assertEqual(66, 
scheduler._config.retry_count) self.assertEqual(77, scheduler._config.worker_disconnect_delay) @with_config({"resources": {"a": "100", "b": "200"}}) def test_scheduler_with_resources(self): scheduler = luigi.scheduler.Scheduler() self.assertEqual({"a": 100, "b": 200}, scheduler._resources) @with_config({"scheduler": {"record_task_history": "True"}, "task_history": {"db_connection": "sqlite:////none/existing/path/hist.db"}}) def test_local_scheduler_task_history_status(self): ls = luigi.interface._WorkerSchedulerFactory().create_local_scheduler() self.assertEqual(False, ls._config.record_task_history) def test_load_recovers_tasks_index(self): scheduler = luigi.scheduler.Scheduler() scheduler.add_task(worker="A", task_id="1") scheduler.add_task(worker="B", task_id="2") scheduler.add_task(worker="C", task_id="3") scheduler.add_task(worker="D", task_id="4") self.assertEqual(scheduler.get_work(worker="A")["task_id"], "1") with tempfile.NamedTemporaryFile(delete=True) as fn: def reload_from_disk(scheduler): scheduler._state._state_path = fn.name scheduler.dump() scheduler = luigi.scheduler.Scheduler() scheduler._state._state_path = fn.name scheduler.load() return scheduler scheduler = reload_from_disk(scheduler=scheduler) self.assertEqual(scheduler.get_work(worker="B")["task_id"], "2") self.assertEqual(scheduler.get_work(worker="C")["task_id"], "3") scheduler = reload_from_disk(scheduler=scheduler) self.assertEqual(scheduler.get_work(worker="D")["task_id"], "4") def test_worker_prune_after_init(self): """ See https://github.com/spotify/luigi/pull/1019 """ worker = luigi.scheduler.Worker(123) class TmpCfg: def __init__(self): self.worker_disconnect_delay = 10 worker.prune(TmpCfg()) def test_get_empty_retry_policy(self): retry_policy = luigi.scheduler._get_empty_retry_policy() self.assertEqual(3, len(retry_policy)) self.assertEqual(["retry_count", "disable_hard_timeout", "disable_window"], list(retry_policy._asdict().keys())) self.assertEqual([None, None, None], 
list(retry_policy._asdict().values())) @with_config({"scheduler": {"retry_count": "9", "disable_hard_timeout": "99", "disable_window": "999"}}) def test_scheduler_get_retry_policy(self): s = luigi.scheduler.Scheduler() self.assertEqual(luigi.scheduler.RetryPolicy(9, 99, 999), s._config._get_retry_policy()) @with_config({"scheduler": {"retry_count": "9", "disable_hard_timeout": "99", "disable_window": "999"}}) def test_generate_retry_policy(self): s = luigi.scheduler.Scheduler() try: s._generate_retry_policy({"inexist_attr": True}) self.assertFalse(True, "'unexpected keyword argument' error must have been thrown") except TypeError: self.assertTrue(True) retry_policy = s._generate_retry_policy({}) self.assertEqual(luigi.scheduler.RetryPolicy(9, 99, 999), retry_policy) retry_policy = s._generate_retry_policy({"retry_count": 1}) self.assertEqual(luigi.scheduler.RetryPolicy(1, 99, 999), retry_policy) retry_policy = s._generate_retry_policy({"retry_count": 1, "disable_hard_timeout": 11, "disable_window": 111}) self.assertEqual(luigi.scheduler.RetryPolicy(1, 11, 111), retry_policy) @with_config({"scheduler": {"retry_count": "44"}}) def test_per_task_retry_policy(self): cps = luigi.scheduler.Scheduler() cps.add_task(worker="test_worker1", task_id="test_task_1", deps=["test_task_2", "test_task_3"]) tasks = list(cps._state.get_active_tasks()) self.assertEqual(3, len(tasks)) tasks = sorted(tasks, key=lambda x: x.id) task_1 = tasks[0] task_2 = tasks[1] task_3 = tasks[2] self.assertEqual("test_task_1", task_1.id) self.assertEqual("test_task_2", task_2.id) self.assertEqual("test_task_3", task_3.id) self.assertEqual(luigi.scheduler.RetryPolicy(44, 999999999, 3600), task_1.retry_policy) self.assertEqual(luigi.scheduler.RetryPolicy(44, 999999999, 3600), task_2.retry_policy) self.assertEqual(luigi.scheduler.RetryPolicy(44, 999999999, 3600), task_3.retry_policy) cps._state._tasks = {} cps.add_task( worker="test_worker2", task_id="test_task_4", deps=["test_task_5", "test_task_6"], 
retry_policy_dict=luigi.scheduler.RetryPolicy(99, 999, 9999)._asdict(), ) tasks = list(cps._state.get_active_tasks()) self.assertEqual(3, len(tasks)) tasks = sorted(tasks, key=lambda x: x.id) task_4 = tasks[0] task_5 = tasks[1] task_6 = tasks[2] self.assertEqual("test_task_4", task_4.id) self.assertEqual("test_task_5", task_5.id) self.assertEqual("test_task_6", task_6.id) self.assertEqual(luigi.scheduler.RetryPolicy(99, 999, 9999), task_4.retry_policy) self.assertEqual(luigi.scheduler.RetryPolicy(44, 999999999, 3600), task_5.retry_policy) self.assertEqual(luigi.scheduler.RetryPolicy(44, 999999999, 3600), task_6.retry_policy) cps._state._tasks = {} cps.add_task(worker="test_worker3", task_id="test_task_7", deps=["test_task_8", "test_task_9"]) cps.add_task(worker="test_worker3", task_id="test_task_8", retry_policy_dict=luigi.scheduler.RetryPolicy(99, 999, 9999)._asdict()) cps.add_task(worker="test_worker3", task_id="test_task_9", retry_policy_dict=luigi.scheduler.RetryPolicy(11, 111, 1111)._asdict()) tasks = list(cps._state.get_active_tasks()) self.assertEqual(3, len(tasks)) tasks = sorted(tasks, key=lambda x: x.id) task_7 = tasks[0] task_8 = tasks[1] task_9 = tasks[2] self.assertEqual("test_task_7", task_7.id) self.assertEqual("test_task_8", task_8.id) self.assertEqual("test_task_9", task_9.id) self.assertEqual(luigi.scheduler.RetryPolicy(44, 999999999, 3600), task_7.retry_policy) self.assertEqual(luigi.scheduler.RetryPolicy(99, 999, 9999), task_8.retry_policy) self.assertEqual(luigi.scheduler.RetryPolicy(11, 111, 1111), task_9.retry_policy) # Task 7 which is disable-failures 44 and its has_excessive_failures method returns False under 44 for i in range(43): task_7.add_failure() self.assertFalse(task_7.has_excessive_failures()) task_7.add_failure() self.assertTrue(task_7.has_excessive_failures()) # Task 8 which is disable-failures 99 and its has_excessive_failures method returns False under 44 for i in range(98): task_8.add_failure() 
self.assertFalse(task_8.has_excessive_failures()) task_8.add_failure() self.assertTrue(task_8.has_excessive_failures()) # Task 9 which is disable-failures 1 and its has_excessive_failures method returns False under 44 for i in range(10): task_9.add_failure() self.assertFalse(task_9.has_excessive_failures()) task_9.add_failure() self.assertTrue(task_9.has_excessive_failures()) @with_config({"scheduler": {"record_task_history": "true"}}) def test_has_task_history(self): cfg = luigi.configuration.get_config() with tempfile.NamedTemporaryFile(suffix=".db", delete=True) as fn: cfg.set("task_history", "db_connection", "sqlite:///" + fn.name) s = luigi.scheduler.Scheduler() self.assertTrue(s.has_task_history()) @with_config({"scheduler": {"record_task_history": "false"}}) def test_has_no_task_history(self): s = luigi.scheduler.Scheduler() self.assertFalse(s.has_task_history()) @with_config({"scheduler": {"pause_enabled": "false"}}) def test_pause_disabled(self): s = luigi.scheduler.Scheduler() self.assertFalse(s.is_pause_enabled()["enabled"]) self.assertFalse(s.is_paused()["paused"]) s.pause() self.assertFalse(s.is_paused()["paused"]) def test_default_metrics_collector(self): from luigi.metrics import MetricsCollector s = luigi.scheduler.Scheduler() scheduler_state = s._state collector = scheduler_state._metrics_collector self.assertTrue(isinstance(collector, MetricsCollector)) @with_config({"scheduler": {"metrics_collector": "datadog"}}) def test_datadog_metrics_collector(self): from luigi.contrib.datadog_metric import DatadogMetricsCollector s = luigi.scheduler.Scheduler() scheduler_state = s._state collector = scheduler_state._metrics_collector self.assertTrue(isinstance(collector, DatadogMetricsCollector)) @with_config({"scheduler": {"metrics_collector": "prometheus"}}) def test_prometheus_metrics_collector(self): from luigi.contrib.prometheus_metric import PrometheusMetricsCollector s = luigi.scheduler.Scheduler() scheduler_state = s._state collector = 
scheduler_state._metrics_collector self.assertTrue(isinstance(collector, PrometheusMetricsCollector)) @with_config({"scheduler": {"metrics_collector": "custom", "metrics_custom_import": "luigi.contrib.prometheus_metric.PrometheusMetricsCollector"}}) def test_custom_metrics_collector(self): from luigi.contrib.prometheus_metric import PrometheusMetricsCollector s = luigi.scheduler.Scheduler() scheduler_state = s._state collector = scheduler_state._metrics_collector self.assertTrue(isinstance(collector, PrometheusMetricsCollector)) class SchedulerWorkerTest(unittest.TestCase): def get_pending_ids(self, worker, state): return {task.id for task in worker.get_tasks(state, "PENDING")} def test_get_pending_tasks_with_many_done_tasks(self): sch = luigi.scheduler.Scheduler() sch.add_task(worker="NON_TRIVIAL", task_id="A", resources={"a": 1}) sch.add_task(worker="TRIVIAL", task_id="B", status="PENDING") sch.add_task(worker="TRIVIAL", task_id="C", status="DONE") sch.add_task(worker="TRIVIAL", task_id="D", status="DONE") scheduler_state = sch._state trivial_worker = scheduler_state.get_worker("TRIVIAL") self.assertEqual({"B"}, self.get_pending_ids(trivial_worker, scheduler_state)) non_trivial_worker = scheduler_state.get_worker("NON_TRIVIAL") self.assertEqual({"A"}, self.get_pending_ids(non_trivial_worker, scheduler_state)) class FailingOnDoubleRunTask(luigi.Task): time_to_check_secs = 1 time_to_run_secs = 2 output_dir = luigi.Parameter(default="") def __init__(self, *args, **kwargs): super(FailingOnDoubleRunTask, self).__init__(*args, **kwargs) self.file_name = os.path.join(self.output_dir, "AnyTask") def complete(self): time.sleep(self.time_to_check_secs) # e.g., establish connection exists = os.path.exists(self.file_name) time.sleep(self.time_to_check_secs) # e.g., close connection return exists def run(self): time.sleep(self.time_to_run_secs) if os.path.exists(self.file_name): raise FileAlreadyExists(self.file_name) open(self.file_name, "w").close() class 
StableDoneCooldownSecsTest(unittest.TestCase): def setUp(self): self.p = tempfile.mkdtemp() def tearDown(self): shutil.rmtree(self.p) def run_task(self): return luigi.build([FailingOnDoubleRunTask(output_dir=self.p)], detailed_summary=True, parallel_scheduling=True, parallel_scheduling_processes=2) @with_config({"worker": {"keep_alive": "false"}}) def get_second_run_result_on_double_run(self): server_process = Process(target=luigi.server.run) process = Process(target=self.run_task) try: # scheduler is started server_process.start() # first run is started process.start() time.sleep(FailingOnDoubleRunTask.time_to_run_secs + FailingOnDoubleRunTask.time_to_check_secs) # second run of the same task is started second_run_result = self.run_task() return second_run_result finally: process.join(1) server_process.terminate() server_process.join(1) @with_config({"scheduler": {"stable_done_cooldown_secs": "5"}}) def test_sending_same_task_twice_with_cooldown_does_not_lead_to_double_run(self): second_run_result = self.get_second_run_result_on_double_run() self.assertEqual(second_run_result.scheduling_succeeded, True) @with_config({"scheduler": {"stable_done_cooldown_secs": "0"}}) def test_sending_same_task_twice_without_cooldown_leads_to_double_run(self): second_run_result = self.get_second_run_result_on_double_run() self.assertEqual(second_run_result.scheduling_succeeded, False) ================================================ FILE: test/scheduler_visualisation_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import os import tempfile import time from helpers import RunOnceTask, unittest import luigi import luigi.notifications import luigi.scheduler import luigi.worker luigi.notifications.DEBUG = True tempdir = tempfile.mkdtemp() class DummyTask(luigi.Task): task_id = luigi.IntParameter() def run(self): f = self.output().open("w") f.close() def output(self): return luigi.LocalTarget(os.path.join(tempdir, str(self))) class FactorTask(luigi.Task): product = luigi.IntParameter() def requires(self): for factor in range(2, self.product): if self.product % factor == 0: yield FactorTask(factor) yield FactorTask(self.product // factor) return def run(self): f = self.output().open("w") f.close() def output(self): return luigi.LocalTarget(os.path.join(tempdir, "luigi_test_factor_%d" % self.product)) class BadReqTask(luigi.Task): succeed = luigi.BoolParameter() def requires(self): assert self.succeed yield BadReqTask(False) def run(self): pass def complete(self): return False class FailingTask(luigi.Task): task_namespace = __name__ task_id = luigi.IntParameter() def complete(self): return False def run(self): raise Exception("Error Message") class OddFibTask(luigi.Task): n = luigi.IntParameter() done = luigi.BoolParameter(default=True, significant=False) def requires(self): if self.n > 1: yield OddFibTask(self.n - 1, self.done) yield OddFibTask(self.n - 2, self.done) def complete(self): return self.n % 2 == 0 and self.done def run(self): assert False class SchedulerVisualisationTest(unittest.TestCase): def setUp(self): self.scheduler = luigi.scheduler.Scheduler() def 
tearDown(self): pass def _assert_complete(self, tasks): for t in tasks: self.assertTrue(t.complete()) def _build(self, tasks): with luigi.worker.Worker(scheduler=self.scheduler, worker_processes=1) as w: for t in tasks: w.add(t) w.run() def _remote(self): return self.scheduler def _test_run(self, workers): tasks = [DummyTask(i) for i in range(20)] self._build(tasks, workers=workers) self._assert_complete(tasks) def test_graph(self): start = time.time() tasks = [DummyTask(task_id=1), DummyTask(task_id=2)] self._build(tasks) self._assert_complete(tasks) end = time.time() remote = self._remote() graph = remote.graph() self.assertEqual(len(graph), 2) self.assertTrue(DummyTask(task_id=1).task_id in graph) d1 = graph[DummyTask(task_id=1).task_id] self.assertEqual(d1["status"], "DONE") self.assertEqual(d1["deps"], []) self.assertGreaterEqual(d1["start_time"], start) self.assertLessEqual(d1["start_time"], end) d2 = graph[DummyTask(task_id=2).task_id] self.assertEqual(d2["status"], "DONE") self.assertEqual(d2["deps"], []) self.assertGreaterEqual(d2["start_time"], start) self.assertLessEqual(d2["start_time"], end) def test_large_graph_truncate(self): class LinearTask(luigi.Task): idx = luigi.IntParameter() def requires(self): if self.idx > 0: yield LinearTask(self.idx - 1) def complete(self): return False root_task = LinearTask(100) self.scheduler = luigi.scheduler.Scheduler(max_graph_nodes=10) self._build([root_task]) graph = self.scheduler.dep_graph(root_task.task_id) self.assertEqual(10, len(graph)) expected_nodes = [LinearTask(i).task_id for i in range(100, 90, -1)] self.assertCountEqual(expected_nodes, graph) def test_large_inverse_graph_truncate(self): class LinearTask(luigi.Task): idx = luigi.IntParameter() def requires(self): if self.idx > 0: yield LinearTask(self.idx - 1) def complete(self): return False root_task = LinearTask(100) self.scheduler = luigi.scheduler.Scheduler(max_graph_nodes=10) self._build([root_task]) graph = 
self.scheduler.inverse_dep_graph(LinearTask(0).task_id) self.assertEqual(10, len(graph)) expected_nodes = [LinearTask(i).task_id for i in range(10)] self.assertCountEqual(expected_nodes, graph) def test_truncate_graph_with_full_levels(self): class BinaryTreeTask(RunOnceTask): idx = luigi.IntParameter() def requires(self): if self.idx < 100: return map(BinaryTreeTask, (self.idx * 2, self.idx * 2 + 1)) root_task = BinaryTreeTask(1) self.scheduler = luigi.scheduler.Scheduler(max_graph_nodes=10) self._build([root_task]) graph = self.scheduler.dep_graph(root_task.task_id) self.assertEqual(10, len(graph)) expected_nodes = [BinaryTreeTask(i).task_id for i in range(1, 11)] self.assertCountEqual(expected_nodes, graph) def test_truncate_graph_with_multiple_depths(self): class LinearTask(luigi.Task): idx = luigi.IntParameter() def requires(self): if self.idx > 0: yield LinearTask(self.idx - 1) yield LinearTask(0) def complete(self): return False root_task = LinearTask(100) self.scheduler = luigi.scheduler.Scheduler(max_graph_nodes=10) self._build([root_task]) graph = self.scheduler.dep_graph(root_task.task_id) self.assertEqual(10, len(graph)) expected_nodes = [LinearTask(i).task_id for i in range(100, 91, -1)] + [LinearTask(0).task_id] self.maxDiff = None self.assertCountEqual(expected_nodes, graph) def _assert_all_done(self, tasks): self._assert_all(tasks, "DONE") def _assert_all(self, tasks, status): for task in tasks.values(): self.assertEqual(task["status"], status) def test_dep_graph_single(self): self._build([FactorTask(1)]) remote = self._remote() dep_graph = remote.dep_graph(FactorTask(product=1).task_id) self.assertEqual(len(dep_graph), 1) self._assert_all_done(dep_graph) d1 = dep_graph.get(FactorTask(product=1).task_id) self.assertEqual(type(d1), type({})) self.assertEqual(d1["deps"], []) def test_dep_graph_not_found(self): self._build([FactorTask(1)]) remote = self._remote() dep_graph = remote.dep_graph(FactorTask(product=5).task_id) 
self.assertEqual(len(dep_graph), 0) def test_inverse_dep_graph_not_found(self): self._build([FactorTask(1)]) remote = self._remote() dep_graph = remote.inverse_dep_graph("FactorTask(product=5)") self.assertEqual(len(dep_graph), 0) def test_dep_graph_tree(self): self._build([FactorTask(30)]) remote = self._remote() dep_graph = remote.dep_graph(FactorTask(product=30).task_id) self.assertEqual(len(dep_graph), 5) self._assert_all_done(dep_graph) d30 = dep_graph[FactorTask(product=30).task_id] self.assertEqual(sorted(d30["deps"]), sorted([FactorTask(product=15).task_id, FactorTask(product=2).task_id])) d2 = dep_graph[FactorTask(product=2).task_id] self.assertEqual(sorted(d2["deps"]), []) d15 = dep_graph[FactorTask(product=15).task_id] self.assertEqual(sorted(d15["deps"]), sorted([FactorTask(product=3).task_id, FactorTask(product=5).task_id])) d3 = dep_graph[FactorTask(product=3).task_id] self.assertEqual(sorted(d3["deps"]), []) d5 = dep_graph[FactorTask(product=5).task_id] self.assertEqual(sorted(d5["deps"]), []) def test_dep_graph_missing_deps(self): self._build([BadReqTask(True)]) dep_graph = self._remote().dep_graph(BadReqTask(succeed=True).task_id) self.assertEqual(len(dep_graph), 2) suc = dep_graph[BadReqTask(succeed=True).task_id] self.assertEqual(suc["deps"], [BadReqTask(succeed=False).task_id]) fail = dep_graph[BadReqTask(succeed=False).task_id] self.assertEqual(fail["name"], "BadReqTask") self.assertEqual(fail["params"], {"succeed": "False"}) self.assertEqual(fail["status"], "UNKNOWN") def test_dep_graph_diamond(self): self._build([FactorTask(12)]) remote = self._remote() dep_graph = remote.dep_graph(FactorTask(product=12).task_id) self.assertEqual(len(dep_graph), 4) self._assert_all_done(dep_graph) d12 = dep_graph[FactorTask(product=12).task_id] self.assertEqual(sorted(d12["deps"]), sorted([FactorTask(product=2).task_id, FactorTask(product=6).task_id])) d6 = dep_graph[FactorTask(product=6).task_id] self.assertEqual(sorted(d6["deps"]), 
sorted([FactorTask(product=2).task_id, FactorTask(product=3).task_id])) d3 = dep_graph[FactorTask(product=3).task_id] self.assertEqual(sorted(d3["deps"]), []) d2 = dep_graph[FactorTask(product=2).task_id] self.assertEqual(sorted(d2["deps"]), []) def test_dep_graph_skip_done(self): task = OddFibTask(9) self._build([task]) remote = self._remote() task_id = task.task_id self.assertEqual(9, len(remote.dep_graph(task_id, include_done=True))) skip_done_graph = remote.dep_graph(task_id, include_done=False) self.assertEqual(5, len(skip_done_graph)) for task in skip_done_graph.values(): self.assertNotEqual("DONE", task["status"]) self.assertLess(len(task["deps"]), 2) def test_inverse_dep_graph_skip_done(self): self._build([OddFibTask(9, done=False)]) self._build([OddFibTask(9, done=True)]) remote = self._remote() task_id = OddFibTask(1).task_id self.assertEqual(9, len(remote.inverse_dep_graph(task_id, include_done=True))) skip_done_graph = remote.inverse_dep_graph(task_id, include_done=False) self.assertEqual(5, len(skip_done_graph)) for task in skip_done_graph.values(): self.assertNotEqual("DONE", task["status"]) self.assertLess(len(task["deps"]), 2) def test_task_list_single(self): self._build([FactorTask(7)]) remote = self._remote() tasks_done = remote.task_list("DONE", "") self.assertEqual(len(tasks_done), 1) self._assert_all_done(tasks_done) t7 = tasks_done.get(FactorTask(product=7).task_id) self.assertEqual(type(t7), type({})) self.assertEqual(remote.task_list("", ""), tasks_done) self.assertEqual(remote.task_list("FAILED", ""), {}) self.assertEqual(remote.task_list("PENDING", ""), {}) def test_dep_graph_root_has_display_name(self): root_task = FactorTask(12) self._build([root_task]) dep_graph = self._remote().dep_graph(root_task.task_id) self.assertEqual("FactorTask(product=12)", dep_graph[root_task.task_id]["display_name"]) def test_dep_graph_non_root_nodes_lack_display_name(self): root_task = FactorTask(12) self._build([root_task]) dep_graph = 
self._remote().dep_graph(root_task.task_id) for task_id, node in dep_graph.items(): if task_id != root_task.task_id: self.assertNotIn("display_name", node) def test_task_list_failed(self): self._build([FailingTask(8)]) remote = self._remote() failed = remote.task_list("FAILED", "") self.assertEqual(len(failed), 1) f8 = failed.get(FailingTask(task_id=8).task_id) self.assertEqual(f8["status"], "FAILED") self.assertEqual(remote.task_list("DONE", ""), {}) self.assertEqual(remote.task_list("PENDING", ""), {}) def test_task_list_upstream_status(self): class A(luigi.ExternalTask): def complete(self): return False class B(luigi.ExternalTask): def complete(self): return True class C(RunOnceTask): def requires(self): return [A(), B()] class F(luigi.Task): def complete(self): return False def run(self): raise Exception() class D(RunOnceTask): def requires(self): return [F()] class E(RunOnceTask): def requires(self): return [C(), D()] self._build([E()]) remote = self._remote() done = remote.task_list("DONE", "") self.assertEqual(len(done), 1) db = done.get(B().task_id) self.assertEqual(db["status"], "DONE") missing_input = remote.task_list("PENDING", "UPSTREAM_MISSING_INPUT") self.assertEqual(len(missing_input), 2) pa = missing_input.get(A().task_id) self.assertEqual(pa["status"], "PENDING") self.assertEqual(remote._upstream_status(A().task_id, {}), "UPSTREAM_MISSING_INPUT") pc = missing_input.get(C().task_id) self.assertEqual(pc["status"], "PENDING") self.assertEqual(remote._upstream_status(C().task_id, {}), "UPSTREAM_MISSING_INPUT") upstream_failed = remote.task_list("PENDING", "UPSTREAM_FAILED") self.assertEqual(len(upstream_failed), 2) pe = upstream_failed.get(E().task_id) self.assertEqual(pe["status"], "PENDING") self.assertEqual(remote._upstream_status(E().task_id, {}), "UPSTREAM_FAILED") pe = upstream_failed.get(D().task_id) self.assertEqual(pe["status"], "PENDING") self.assertEqual(remote._upstream_status(D().task_id, {}), "UPSTREAM_FAILED") pending = 
dict(missing_input) pending.update(upstream_failed) self.assertEqual(remote.task_list("PENDING", ""), pending) self.assertEqual(remote.task_list("PENDING", "UPSTREAM_RUNNING"), {}) failed = remote.task_list("FAILED", "") self.assertEqual(len(failed), 1) fd = failed.get(F().task_id) self.assertEqual(fd["status"], "FAILED") all = dict(pending) all.update(done) all.update(failed) self.assertEqual(remote.task_list("", ""), all) self.assertEqual(remote.task_list("RUNNING", ""), {}) def test_task_search(self): self._build([FactorTask(8)]) self._build([FailingTask(8)]) remote = self._remote() all_tasks = remote.task_search("Task") self.assertEqual(len(all_tasks), 2) self._assert_all(all_tasks["DONE"], "DONE") self._assert_all(all_tasks["FAILED"], "FAILED") def test_fetch_error(self): self._build([FailingTask(8)]) remote = self._remote() error = remote.fetch_error(FailingTask(task_id=8).task_id) self.assertEqual(error["taskId"], FailingTask(task_id=8).task_id) self.assertTrue("Error Message" in error["error"]) self.assertTrue("Runtime error" in error["error"]) self.assertTrue("Traceback" in error["error"]) def test_inverse_deps(self): class X(RunOnceTask): pass class Y(RunOnceTask): def requires(self): return [X()] class Z(RunOnceTask): id = luigi.IntParameter() def requires(self): return [Y()] class ZZ(RunOnceTask): def requires(self): return [Z(1), Z(2)] self._build([ZZ()]) dep_graph = self._remote().inverse_dep_graph(X().task_id) def assert_has_deps(task_id, deps): self.assertTrue(task_id in dep_graph, "%s not in dep_graph %s" % (task_id, dep_graph)) task = dep_graph[task_id] self.assertEqual(sorted(task["deps"]), sorted(deps), "%s does not have deps %s" % (task_id, deps)) assert_has_deps(X().task_id, [Y().task_id]) assert_has_deps(Y().task_id, [Z(id=1).task_id, Z(id=2).task_id]) assert_has_deps(Z(id=1).task_id, [ZZ().task_id]) assert_has_deps(Z(id=2).task_id, [ZZ().task_id]) assert_has_deps(ZZ().task_id, []) def test_simple_worker_list(self): class X(luigi.Task): def 
run(self): self._complete = True def complete(self): return getattr(self, "_complete", False) task_x = X() self._build([task_x]) workers = self._remote().worker_list() self.assertEqual(1, len(workers)) worker = workers[0] self.assertEqual(task_x.task_id, worker["first_task"]) self.assertEqual(0, worker["num_pending"]) self.assertEqual(0, worker["num_uniques"]) self.assertEqual(0, worker["num_running"]) self.assertEqual("active", worker["state"]) self.assertEqual(1, worker["workers"]) def test_worker_list_pending_uniques(self): class X(luigi.Task): def complete(self): return False class Y(X): def requires(self): return X() class Z(Y): pass w1 = luigi.worker.Worker(scheduler=self.scheduler, worker_processes=1) w2 = luigi.worker.Worker(scheduler=self.scheduler, worker_processes=1) w1.add(Y()) w2.add(Z()) workers = self._remote().worker_list() self.assertEqual(2, len(workers)) for worker in workers: self.assertEqual(2, worker["num_pending"]) self.assertEqual(1, worker["num_uniques"]) self.assertEqual(0, worker["num_running"]) def test_worker_list_running(self): class X(RunOnceTask): n = luigi.IntParameter() w = luigi.worker.Worker(worker_id="w", scheduler=self.scheduler, worker_processes=3) w.add(X(0)) w.add(X(1)) w.add(X(2)) w.add(X(3)) self.scheduler.get_work(worker="w") self.scheduler.get_work(worker="w") self.scheduler.get_work(worker="w") workers = self._remote().worker_list() self.assertEqual(1, len(workers)) worker = workers[0] self.assertEqual(3, worker["num_running"]) self.assertEqual(1, worker["num_pending"]) self.assertEqual(1, worker["num_uniques"]) def test_worker_list_disabled_worker(self): class X(RunOnceTask): pass with luigi.worker.Worker(worker_id="w", scheduler=self.scheduler) as w: w.add(X()) # workers = self._remote().worker_list() self.assertEqual(1, len(workers)) self.assertEqual("active", workers[0]["state"]) self.scheduler.disable_worker("w") workers = self._remote().worker_list() self.assertEqual(1, len(workers)) self.assertEqual(1, 
len(workers)) self.assertEqual("disabled", workers[0]["state"]) ================================================ FILE: test/server_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import functools import multiprocessing import os import shutil import signal import tempfile import time from urllib.parse import ParseResult, urlencode from urllib.parse import quote as urlquote import pytest import tornado.ioloop from helpers import skipOnTravisAndGithubActions, unittest from tornado.testing import AsyncHTTPTestCase import luigi.cmdline import luigi.rpc import luigi.server from luigi.configuration import get_config from luigi.scheduler import Scheduler try: from unittest import mock except ImportError: import mock def _is_running_from_main_thread(): """ Return true if we're the same thread as the one that created the Tornado IOLoop. In practice, the problem is that we get annoying intermittent failures because sometimes the KeepAliveThread jumps in and "disturbs" the intended flow of the test case. Worse, it fails in the terrible way that the KeepAliveThread is kept alive, bugging the execution of subsequent test casses. Oh, I so wish Tornado would explicitly say that you're acessing it from different threads and things will just not work. 
""" return tornado.ioloop.IOLoop.current(instance=False) class ServerTestBase(AsyncHTTPTestCase): def get_app(self): return luigi.server.app(Scheduler()) def setUp(self): super(ServerTestBase, self).setUp() self._old_fetch = luigi.rpc.RemoteScheduler._fetch def _fetch(obj, url, body, *args, **kwargs): if _is_running_from_main_thread(): body = urlencode(body).encode("utf-8") response = self.fetch(url, body=body, method="POST") if response.code >= 400: raise luigi.rpc.RPCError("Errror when connecting to remote scheduler") return response.body.decode("utf-8") luigi.rpc.RemoteScheduler._fetch = _fetch def tearDown(self): super(ServerTestBase, self).tearDown() luigi.rpc.RemoteScheduler._fetch = self._old_fetch class ServerTest(ServerTestBase): def setUp(self): super(ServerTest, self).setUp() get_config().remove_section("cors") self._default_cors = luigi.server.cors() get_config().set("cors", "enabled", "true") get_config().set("cors", "allow_any_origin", "true") get_config().set("cors", "allow_null_origin", "true") def tearDown(self): super(ServerTest, self).tearDown() get_config().remove_section("cors") def test_visualiser(self): page = self.fetch("/").body self.assertTrue(page.find(b"") != -1) def _test_404(self, path): response = self.fetch(path) self.assertEqual(response.code, 404) def test_404(self): self._test_404("/foo") def test_api_404(self): self._test_404("/api/foo") def test_root_redirect(self): response = self.fetch("/", follow_redirects=False) self.assertEqual(response.code, 302) self.assertEqual(response.headers["Location"], "static/visualiser/index.html") # assert that doesnt begin with leading slash ! 
def test_api_preflight_cors_headers(self): response = self.fetch("/api/graph", method="OPTIONS", headers={"Origin": "foo"}) headers = dict(response.headers) self.assertEqual(self._default_cors.allowed_headers, headers["Access-Control-Allow-Headers"]) self.assertEqual(self._default_cors.allowed_methods, headers["Access-Control-Allow-Methods"]) self.assertEqual("*", headers["Access-Control-Allow-Origin"]) self.assertEqual(str(self._default_cors.max_age), headers["Access-Control-Max-Age"]) self.assertIsNone(headers.get("Access-Control-Allow-Credentials")) self.assertIsNone(headers.get("Access-Control-Expose-Headers")) def test_api_preflight_cors_headers_all_response_headers(self): get_config().set("cors", "allow_credentials", "true") get_config().set("cors", "exposed_headers", "foo, bar") response = self.fetch("/api/graph", method="OPTIONS", headers={"Origin": "foo"}) headers = dict(response.headers) self.assertEqual(self._default_cors.allowed_headers, headers["Access-Control-Allow-Headers"]) self.assertEqual(self._default_cors.allowed_methods, headers["Access-Control-Allow-Methods"]) self.assertEqual("*", headers["Access-Control-Allow-Origin"]) self.assertEqual(str(self._default_cors.max_age), headers["Access-Control-Max-Age"]) self.assertEqual("true", headers["Access-Control-Allow-Credentials"]) self.assertEqual("foo, bar", headers["Access-Control-Expose-Headers"]) def test_api_preflight_cors_headers_null_origin(self): response = self.fetch("/api/graph", method="OPTIONS", headers={"Origin": "null"}) headers = dict(response.headers) self.assertEqual(self._default_cors.allowed_headers, headers["Access-Control-Allow-Headers"]) self.assertEqual(self._default_cors.allowed_methods, headers["Access-Control-Allow-Methods"]) self.assertEqual("null", headers["Access-Control-Allow-Origin"]) self.assertEqual(str(self._default_cors.max_age), headers["Access-Control-Max-Age"]) self.assertIsNone(headers.get("Access-Control-Allow-Credentials")) 
self.assertIsNone(headers.get("Access-Control-Expose-Headers")) def test_api_preflight_cors_headers_disallow_null(self): get_config().set("cors", "allow_null_origin", "false") response = self.fetch("/api/graph", method="OPTIONS", headers={"Origin": "null"}) headers = dict(response.headers) self.assertNotIn("Access-Control-Allow-Headers", headers) self.assertNotIn("Access-Control-Allow-Methods", headers) self.assertNotIn("Access-Control-Allow-Origin", headers) self.assertNotIn("Access-Control-Max-Age", headers) self.assertNotIn("Access-Control-Allow-Credentials", headers) self.assertNotIn("Access-Control-Expose-Headers", headers) def test_api_preflight_cors_headers_disallow_any(self): get_config().set("cors", "allow_any_origin", "false") get_config().set("cors", "allowed_origins", '["foo", "bar"]') response = self.fetch("/api/graph", method="OPTIONS", headers={"Origin": "foo"}) headers = dict(response.headers) self.assertEqual(self._default_cors.allowed_headers, headers["Access-Control-Allow-Headers"]) self.assertEqual(self._default_cors.allowed_methods, headers["Access-Control-Allow-Methods"]) self.assertEqual("foo", headers["Access-Control-Allow-Origin"]) self.assertEqual(str(self._default_cors.max_age), headers["Access-Control-Max-Age"]) self.assertIsNone(headers.get("Access-Control-Allow-Credentials")) self.assertIsNone(headers.get("Access-Control-Expose-Headers")) def test_api_preflight_cors_headers_disallow_any_no_matched_allowed_origins(self): get_config().set("cors", "allow_any_origin", "false") get_config().set("cors", "allowed_origins", '["foo", "bar"]') response = self.fetch("/api/graph", method="OPTIONS", headers={"Origin": "foobar"}) headers = dict(response.headers) self.assertNotIn("Access-Control-Allow-Headers", headers) self.assertNotIn("Access-Control-Allow-Methods", headers) self.assertNotIn("Access-Control-Allow-Origin", headers) self.assertNotIn("Access-Control-Max-Age", headers) self.assertNotIn("Access-Control-Allow-Credentials", headers) 
self.assertNotIn("Access-Control-Expose-Headers", headers) def test_api_preflight_cors_headers_disallow_any_no_allowed_origins(self): get_config().set("cors", "allow_any_origin", "false") response = self.fetch("/api/graph", method="OPTIONS", headers={"Origin": "foo"}) headers = dict(response.headers) self.assertNotIn("Access-Control-Allow-Headers", headers) self.assertNotIn("Access-Control-Allow-Methods", headers) self.assertNotIn("Access-Control-Allow-Origin", headers) self.assertNotIn("Access-Control-Max-Age", headers) self.assertNotIn("Access-Control-Allow-Credentials", headers) self.assertNotIn("Access-Control-Expose-Headers", headers) def test_api_preflight_cors_headers_disabled(self): get_config().set("cors", "enabled", "false") response = self.fetch("/api/graph", method="OPTIONS", headers={"Origin": "foo"}) headers = dict(response.headers) self.assertNotIn("Access-Control-Allow-Headers", headers) self.assertNotIn("Access-Control-Allow-Methods", headers) self.assertNotIn("Access-Control-Allow-Origin", headers) self.assertNotIn("Access-Control-Max-Age", headers) self.assertNotIn("Access-Control-Allow-Credentials", headers) self.assertNotIn("Access-Control-Expose-Headers", headers) def test_api_preflight_cors_headers_no_origin_header(self): response = self.fetch("/api/graph", method="OPTIONS") headers = dict(response.headers) self.assertNotIn("Access-Control-Allow-Headers", headers) self.assertNotIn("Access-Control-Allow-Methods", headers) self.assertNotIn("Access-Control-Allow-Origin", headers) self.assertNotIn("Access-Control-Max-Age", headers) self.assertNotIn("Access-Control-Allow-Credentials", headers) self.assertNotIn("Access-Control-Expose-Headers", headers) def test_api_cors_headers(self): response = self.fetch("/api/graph", headers={"Origin": "foo"}) headers = dict(response.headers) self.assertEqual("*", headers["Access-Control-Allow-Origin"]) def test_api_cors_headers_null_origin(self): response = self.fetch("/api/graph", headers={"Origin": "null"}) 
headers = dict(response.headers) self.assertEqual("null", headers["Access-Control-Allow-Origin"]) def test_api_cors_headers_disallow_null(self): get_config().set("cors", "allow_null_origin", "false") response = self.fetch("/api/graph", headers={"Origin": "null"}) headers = dict(response.headers) self.assertIsNone(headers.get("Access-Control-Allow-Origin")) def test_api_cors_headers_disallow_any(self): get_config().set("cors", "allow_any_origin", "false") get_config().set("cors", "allowed_origins", '["foo", "bar"]') response = self.fetch("/api/graph", headers={"Origin": "foo"}) headers = dict(response.headers) self.assertEqual("foo", headers["Access-Control-Allow-Origin"]) def test_api_cors_headers_disallow_any_no_matched_allowed_origins(self): get_config().set("cors", "allow_any_origin", "false") get_config().set("cors", "allowed_origins", '["foo", "bar"]') response = self.fetch("/api/graph", headers={"Origin": "foobar"}) headers = dict(response.headers) self.assertIsNone(headers.get("Access-Control-Allow-Origin")) def test_api_cors_headers_disallow_any_no_allowed_origins(self): get_config().set("cors", "allow_any_origin", "false") response = self.fetch("/api/graph", headers={"Origin": "foo"}) headers = dict(response.headers) self.assertIsNone(headers.get("Access-Control-Allow-Origin")) def test_api_cors_headers_disabled(self): get_config().set("cors", "enabled", "false") response = self.fetch("/api/graph", headers={"Origin": "foo"}) headers = dict(response.headers) self.assertIsNone(headers.get("Access-Control-Allow-Origin")) def test_api_cors_headers_no_origin_header(self): response = self.fetch("/api/graph") headers = dict(response.headers) self.assertIsNone(headers.get("Access-Control-Allow-Origin")) def test_api_allow_head_on_root(self): response = self.fetch("/", method="HEAD") self.assertEqual(response.code, 204) class _ServerTest(unittest.TestCase): """ Test to start and stop the server in a more "standard" way """ server_client_class = "To be defined by 
subclasses" def start_server(self): self._process = multiprocessing.Process(target=self.server_client.run_server) self._process.start() time.sleep(0.1) # wait for server to start self.sch = self.server_client.scheduler() self.sch._wait = lambda: None def stop_server(self): self._process.terminate() self._process.join(timeout=1) if self._process.is_alive(): os.kill(self._process.pid, signal.SIGKILL) def setUp(self): self.server_client = self.server_client_class() fd, state_path = tempfile.mkstemp(suffix=self.id()) os.close(fd) self.addCleanup(functools.partial(os.unlink, state_path)) luigi.configuration.get_config().set("scheduler", "state_path", state_path) self.start_server() def tearDown(self): self.stop_server() @skipOnTravisAndGithubActions("https://travis-ci.org/spotify/luigi/jobs/78315794") def test_ping(self): self.sch.ping(worker="xyz") @skipOnTravisAndGithubActions("https://travis-ci.org/spotify/luigi/jobs/78023665") def test_raw_ping(self): self.sch._request("/api/ping", {"worker": "xyz"}) @skipOnTravisAndGithubActions("https://travis-ci.org/spotify/luigi/jobs/78023665") def test_raw_ping_extended(self): self.sch._request("/api/ping", {"worker": "xyz", "foo": "bar"}) @skipOnTravisAndGithubActions("https://travis-ci.org/spotify/luigi/jobs/166833694") def test_404(self): with self.assertRaises(luigi.rpc.RPCError): self.sch._request("/api/fdsfds", {"dummy": 1}) @skipOnTravisAndGithubActions("https://travis-ci.org/spotify/luigi/jobs/72953884") def test_save_state(self): self.sch.add_task(worker="X", task_id="B", deps=("A",)) self.sch.add_task(worker="X", task_id="A") self.assertEqual(self.sch.get_work(worker="X")["task_id"], "A") self.stop_server() self.start_server() work = self.sch.get_work(worker="X")["running_tasks"][0] self.assertEqual(work["task_id"], "A") @pytest.mark.unixsocket class UNIXServerTest(_ServerTest): class ServerClient: def __init__(self): self.tempdir = tempfile.mkdtemp() self.unix_socket = os.path.join(self.tempdir, "luigid.sock") def 
run_server(self): luigi.server.run(unix_socket=self.unix_socket) def scheduler(self): url = ParseResult( scheme="http+unix", netloc=urlquote(self.unix_socket, safe=""), path="", params="", query="", fragment="", ).geturl() return luigi.rpc.RemoteScheduler(url) server_client_class = ServerClient def tearDown(self): super(UNIXServerTest, self).tearDown() shutil.rmtree(self.server_client.tempdir) class INETServerClient: def __init__(self): # Just some port self.port = 8083 def scheduler(self): return luigi.rpc.RemoteScheduler("http://localhost:" + str(self.port)) class _INETServerTest(_ServerTest): # HACK: nose ignores class whose name starts with underscore # see: https://github.com/nose-devs/nose/blob/6f9dada1a5593b2365859bab92c7d1e468b64b7b/nose/selector.py#L72 # This hack affects derived classes of this class e.g. INETProcessServerTest, INETLuigidServerTest, INETLuigidDaemonServerTest. __test__ = False def test_with_cmdline(self): """ Test to run against the server as a normal luigi invocation does """ params = ["Task", "--scheduler-port", str(self.server_client.port), "--no-lock"] self.assertTrue(luigi.interface.run(params)) class INETProcessServerTest(_INETServerTest): __test__ = True class ServerClient(INETServerClient): def run_server(self): luigi.server.run(api_port=self.port, address="127.0.0.1") server_client_class = ServerClient class INETURLLibServerTest(INETProcessServerTest): @mock.patch.object(luigi.rpc, "HAS_REQUESTS", False) def start_server(self, *args, **kwargs): super(INETURLLibServerTest, self).start_server(*args, **kwargs) @skipOnTravisAndGithubActions("https://travis-ci.org/spotify/luigi/jobs/81022689") def patching_test(self): """ Check that HAS_REQUESTS patching is meaningful """ fetcher1 = luigi.rpc.RemoteScheduler()._fetcher with mock.patch.object(luigi.rpc, "HAS_REQUESTS", False): fetcher2 = luigi.rpc.RemoteScheduler()._fetcher self.assertNotEqual(fetcher1.__class__, fetcher2.__class__) class INETLuigidServerTest(_INETServerTest): __test__ 
= True class ServerClient(INETServerClient): def run_server(self): # I first tried to things like "subprocess.call(['luigid', ...]), # But it ended up to be a total mess getting the cleanup to work # unfortunately. luigi.cmdline.luigid(["--port", str(self.port)]) server_client_class = ServerClient class INETLuigidDaemonServerTest(_INETServerTest): __test__ = True class ServerClient(INETServerClient): def __init__(self): super(INETLuigidDaemonServerTest.ServerClient, self).__init__() self.tempdir = tempfile.mkdtemp() @mock.patch("daemon.DaemonContext") def run_server(self, daemon_context): luigi.cmdline.luigid( [ "--port", str(self.port), "--background", # This makes it a daemon "--logdir", self.tempdir, "--pidfile", os.path.join(self.tempdir, "luigid.pid"), ] ) def tearDown(self): super(INETLuigidDaemonServerTest, self).tearDown() shutil.rmtree(self.server_client.tempdir) server_client_class = ServerClient class MetricsHandlerTest(unittest.TestCase): def setUp(self): self.mock_scheduler = mock.MagicMock() self.handler = luigi.server.MetricsHandler(tornado.web.Application(), mock.MagicMock(), scheduler=self.mock_scheduler) def test_initialize(self): self.assertIs(self.handler._scheduler, self.mock_scheduler) def test_get(self): mock_metrics = mock.MagicMock() self.mock_scheduler._state._metrics_collector.generate_latest.return_value = mock_metrics with mock.patch.object(self.handler, "write") as patched_write: self.handler.get() patched_write.assert_called_once_with(mock_metrics) self.mock_scheduler._state._metrics_collector.configure_http_handler.assert_called_once_with(self.handler) def test_get_no_metrics(self): self.mock_scheduler._state._metrics_collector.generate_latest.return_value = None with mock.patch.object(self.handler, "write") as patched_write: self.handler.get() patched_write.assert_not_called() class FromUtcTest(unittest.TestCase): def test_with_microseconds(self): """Test parsing UTC time string with microseconds""" result = 
luigi.server.from_utc("2021-01-15 10:30:45.123456") self.assertIsInstance(result, int) def test_without_microseconds(self): """Test parsing UTC time string without microseconds""" result = luigi.server.from_utc("2021-01-15 10:30:45") self.assertIsInstance(result, int) def test_with_custom_format(self): """Test parsing with custom format""" result = luigi.server.from_utc("01/15/2021", fmt="%m/%d/%Y") self.assertIsInstance(result, int) def test_invalid_format_raises_error(self): """Test that invalid format raises ValueError""" with self.assertRaises(ValueError): luigi.server.from_utc("invalid-date") def test_custom_format_mismatch_raises_error(self): """Test that mismatched custom format raises ValueError""" with self.assertRaises(ValueError): luigi.server.from_utc("2021-01-15", fmt="%m/%d/%Y") ================================================ FILE: test/set_task_name_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from helpers import unittest import luigi def create_class(cls_name): class NewTask(luigi.WrapperTask): pass NewTask.__name__ = cls_name return NewTask create_class("MyNewTask") class SetTaskNameTest(unittest.TestCase): """I accidentally introduced an issue in this commit: https://github.com/spotify/luigi/commit/6330e9d0332e6152996292a39c42f752b9288c96 This causes tasks not to get exposed if they change name later. 
Adding a unit test to resolve the issue.""" def test_set_task_name(self): luigi.run(["--local-scheduler", "--no-lock", "MyNewTask"]) ================================================ FILE: test/setup_logging_test.py ================================================ from helpers import unittest from luigi.configuration import LuigiConfigParser, LuigiTomlParser, get_config from luigi.setup_logging import DaemonLogging, InterfaceLogging class TestDaemonLogging(unittest.TestCase): cls = DaemonLogging def setUp(self): self.cls._configured = False def tearDown(self): self.cls._configured = False self.cls.config = get_config() def test_cli(self): opts = type("opts", (), {}) opts.background = True result = self.cls._cli(opts) self.assertTrue(result) opts.background = False opts.logdir = "./tests/" result = self.cls._cli(opts) self.assertTrue(result) opts.background = False opts.logdir = False result = self.cls._cli(opts) self.assertFalse(result) def test_section(self): self.cls.config = { "logging": { "version": 1, "disable_existing_loggers": False, "formatters": { "mockformatter": { "format": "{levelname}: {message}", "style": "{", "datefmt": "%Y-%m-%d %H:%M:%S", }, }, "handlers": { "mockhandler": { "class": "logging.StreamHandler", "level": "INFO", "formatter": "mockformatter", }, }, "loggers": { "mocklogger": { "handlers": ("mockhandler",), "level": "INFO", "disabled": False, "propagate": False, }, }, }, } result = self.cls._section(None) self.assertTrue(result) self.cls.config = LuigiTomlParser() self.cls.config.read(["./test/testconfig/luigi_logging.toml"]) result = self.cls._section(None) self.assertTrue(result) self.cls.config = {} result = self.cls._section(None) self.assertFalse(result) def test_section_cfg(self): self.cls.config = LuigiConfigParser.instance() result = self.cls._section(None) self.assertFalse(result) def test_cfg(self): self.cls.config = LuigiTomlParser() self.cls.config.data = {} result = self.cls._conf(None) self.assertFalse(result) 
self.cls.config.data = {"core": {"logging_conf_file": "./blah"}} with self.assertRaises(OSError): self.cls._conf(None) self.cls.config.data = { "core": { "logging_conf_file": "./test/testconfig/logging.cfg", } } result = self.cls._conf(None) self.assertTrue(result) def test_default(self): result = self.cls._default(None) self.assertTrue(result) class TestInterfaceLogging(TestDaemonLogging): cls = InterfaceLogging def test_cli(self): opts = type("opts", (), {}) result = self.cls._cli(opts) self.assertFalse(result) # test_section inherited from TestDaemonLogging def test_cfg(self): self.cls.config = LuigiTomlParser() self.cls.config.data = {} opts = type("opts", (), {}) opts.logging_conf_file = "" result = self.cls._conf(opts) self.assertFalse(result) opts.logging_conf_file = "./blah" with self.assertRaises(OSError): self.cls._conf(opts) opts.logging_conf_file = "./test/testconfig/logging.cfg" result = self.cls._conf(opts) self.assertTrue(result) def test_default(self): opts = type("opts", (), {}) opts.log_level = "INFO" result = self.cls._default(opts) self.assertTrue(result) class PatchedLogging(InterfaceLogging): @classmethod def _cli(cls, *args): cls.calls.append("_cli") return "_cli" not in cls.patched @classmethod def _conf(cls, *args): cls.calls.append("_conf") return "_conf" not in cls.patched @classmethod def _section(cls, *args): cls.calls.append("_section") return "_section" not in cls.patched @classmethod def _default(cls, *args): cls.calls.append("_default") return "_default" not in cls.patched class TestSetup(unittest.TestCase): def setUp(self): self.opts = type("opts", (), {}) self.cls = PatchedLogging self.cls.calls = [] self.cls.config = LuigiTomlParser() self.cls._configured = False self.cls.patched = "_cli", "_conf", "_section", "_default" def tearDown(self): self.cls.config = get_config() def test_configured(self): self.cls._configured = True result = self.cls.setup(self.opts) self.assertEqual(self.cls.calls, []) self.assertFalse(result) def 
test_disabled(self): self.cls.config.data = {"core": {"no_configure_logging": True}} result = self.cls.setup(self.opts) self.assertEqual(self.cls.calls, []) self.assertFalse(result) def test_order(self): self.cls.setup(self.opts) self.assertEqual(self.cls.calls, ["_cli", "_conf", "_section", "_default"]) def test_cli(self): self.cls.patched = () result = self.cls.setup(self.opts) self.assertTrue(result) self.assertEqual(self.cls.calls, ["_cli"]) def test_conf(self): self.cls.patched = ("_cli",) result = self.cls.setup(self.opts) self.assertTrue(result) self.assertEqual(self.cls.calls, ["_cli", "_conf"]) def test_section(self): self.cls.patched = ("_cli", "_conf") result = self.cls.setup(self.opts) self.assertTrue(result) self.assertEqual(self.cls.calls, ["_cli", "_conf", "_section"]) def test_default(self): self.cls.patched = ("_cli", "_conf", "_section") result = self.cls.setup(self.opts) self.assertTrue(result) self.assertEqual(self.cls.calls, ["_cli", "_conf", "_section", "_default"]) ================================================ FILE: test/simulate_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import os import tempfile from multiprocessing import Process from helpers import unittest import luigi from luigi.contrib.simulate import RunAnywayTarget def temp_dir(): return os.path.join(tempfile.gettempdir(), "luigi-simulate") def is_writable(): d = temp_dir() fn = os.path.join(d, "luigi-simulate-write-test") exists = True try: try: os.makedirs(d) except OSError: pass open(fn, "w").close() os.remove(fn) except BaseException: exists = False return unittest.skipIf(not exists, "Can't write to temporary directory") class TaskA(luigi.Task): i = luigi.IntParameter(default=0) def output(self): return RunAnywayTarget(self) def run(self): fn = os.path.join(temp_dir(), "luigi-simulate-test.tmp") try: os.makedirs(os.path.dirname(fn)) except OSError: pass with open(fn, "a") as f: f.write("{0}={1}\n".format(self.__class__.__name__, self.i)) self.output().done() class TaskB(TaskA): def requires(self): return TaskA(i=10) class TaskC(TaskA): def requires(self): return TaskA(i=5) class TaskD(TaskA): def requires(self): return [TaskB(), TaskC(), TaskA(i=20)] class TaskWrap(luigi.WrapperTask): def requires(self): return [TaskA(), TaskD()] def reset(): # Force tasks to be executed again (because multiple pipelines are executed inside of the same process) t = TaskA().output() with t.unique.get_lock(): t.unique.value = 0 class RunAnywayTargetTest(unittest.TestCase): @is_writable() def test_output(self): reset() fn = os.path.join(temp_dir(), "luigi-simulate-test.tmp") luigi.build([TaskWrap()], local_scheduler=True) with open(fn, "r") as f: data = f.read().strip().split("\n") data.sort() reference = ["TaskA=0", "TaskA=10", "TaskA=20", "TaskA=5", "TaskB=0", "TaskC=0", "TaskD=0"] reference.sort() os.remove(fn) self.assertEqual(data, reference) @is_writable() def test_output_again(self): # Running the test in another process because the PID is used to determine if the target exists p = Process(target=self.test_output) p.start() p.join() ================================================ 
FILE: test/subtask_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import abc from helpers import unittest import luigi class AbstractTask(luigi.Task): k = luigi.IntParameter() @property @abc.abstractmethod def foo(self): raise NotImplementedError @abc.abstractmethod def helper_function(self): raise NotImplementedError def run(self): return ",".join([self.foo, self.helper_function()]) class Implementation(AbstractTask): @property def foo(self): return "bar" def helper_function(self): return "hello" * self.k class AbstractSubclassTest(unittest.TestCase): def test_instantiate_abstract(self): def try_instantiate(): AbstractTask(k=1) self.assertRaises(TypeError, try_instantiate) def test_instantiate(self): self.assertEqual("bar,hellohello", Implementation(k=2).run()) ================================================ FILE: test/target_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import random import re from helpers import skipOnTravisAndGithubActions, unittest from mock import Mock import luigi.format import luigi.target class TestException(Exception): pass class TargetTest(unittest.TestCase): def test_cannot_instantiate(self): def instantiate_target(): luigi.target.Target() self.assertRaises(TypeError, instantiate_target) def test_abstract_subclass(self): class ExistsLessTarget(luigi.target.Target): pass def instantiate_target(): ExistsLessTarget() self.assertRaises(TypeError, instantiate_target) def test_instantiate_subclass(self): class GoodTarget(luigi.target.Target): def exists(self): return True def open(self, mode): return None GoodTarget() class FileSystemTargetTestMixin: """All Target that take bytes (python2: str) should pass those tests. 
In addition, a test to verify the method `exists`should be added """ def create_target(self, format=None): raise NotImplementedError() def assertCleanUp(self, tmp_path=""): pass def test_atomicity(self): target = self.create_target() fobj = target.open("w") self.assertFalse(target.exists()) fobj.close() self.assertTrue(target.exists()) def test_readback(self): target = self.create_target() origdata = "lol\n" fobj = target.open("w") fobj.write(origdata) fobj.close() fobj = target.open("r") data = fobj.read() self.assertEqual(origdata, data) def test_unicode_obj(self): target = self.create_target() origdata = "lol\n" fobj = target.open("w") fobj.write(origdata) fobj.close() fobj = target.open("r") data = fobj.read() self.assertEqual(origdata, data) def test_with_close(self): target = self.create_target() with target.open("w") as fobj: tp = getattr(fobj, "tmp_path", "") fobj.write("hej\n") self.assertCleanUp(tp) self.assertTrue(target.exists()) def test_with_exception(self): target = self.create_target() a = {} def foo(): with target.open("w") as fobj: fobj.write("hej\n") a["tp"] = getattr(fobj, "tmp_path", "") raise TestException("Test triggered exception") self.assertRaises(TestException, foo) self.assertCleanUp(a["tp"]) self.assertFalse(target.exists()) def test_del(self): t = self.create_target() p = t.open("w") print("test", file=p) tp = getattr(p, "tmp_path", "") del p self.assertCleanUp(tp) self.assertFalse(t.exists()) def test_write_cleanup_no_close(self): t = self.create_target() def context(): f = t.open("w") f.write("stuff") return getattr(f, "tmp_path", "") tp = context() import gc gc.collect() # force garbage collection of f variable self.assertCleanUp(tp) self.assertFalse(t.exists()) def test_text(self): t = self.create_target(luigi.format.UTF8) a = "我éçф" with t.open("w") as f: f.write(a) with t.open("r") as f: b = f.read() self.assertEqual(a, b) def test_del_with_Text(self): t = self.create_target(luigi.format.UTF8) p = t.open("w") print("test", 
file=p) tp = getattr(p, "tmp_path", "") del p self.assertCleanUp(tp) self.assertFalse(t.exists()) def test_format_injection(self): class CustomFormat(luigi.format.Format): def pipe_reader(self, input_pipe): input_pipe.foo = "custom read property" return input_pipe def pipe_writer(self, output_pipe): output_pipe.foo = "custom write property" return output_pipe t = self.create_target(CustomFormat()) with t.open("w") as f: self.assertEqual(f.foo, "custom write property") with t.open("r") as f: self.assertEqual(f.foo, "custom read property") @skipOnTravisAndGithubActions("https://travis-ci.org/spotify/luigi/jobs/73693470") def test_binary_write(self): t = self.create_target(luigi.format.Nop) with t.open("w") as f: f.write(b"a\xf2\xf3\r\nfd") with t.open("r") as f: c = f.read() self.assertEqual(c, b"a\xf2\xf3\r\nfd") def test_writelines(self): t = self.create_target() with t.open("w") as f: f.writelines( [ "a\n", "b\n", "c\n", ] ) with t.open("r") as f: c = f.read() self.assertEqual(c, "a\nb\nc\n") def test_read_iterator(self): t = self.create_target() with t.open("w") as f: f.write("a\nb\nc\n") c = [] with t.open("r") as f: for x in f: c.append(x) self.assertEqual(c, ["a\n", "b\n", "c\n"]) def test_gzip(self): t = self.create_target(luigi.format.Gzip) p = t.open("w") test_data = b"test" p.write(test_data) tp = getattr(p, "tmp_path", "") self.assertFalse(t.exists()) p.close() self.assertCleanUp(tp) self.assertTrue(t.exists()) def test_gzip_works_and_cleans_up(self): t = self.create_target(luigi.format.Gzip) test_data = b"123testing" with t.open("w") as f: tp = getattr(f, "tmp_path", "") f.write(test_data) self.assertCleanUp(tp) with t.open() as f: result = f.read() self.assertEqual(test_data, result) def test_move_on_fs(self): # We're cheating and retrieving the fs from target. 
# TODO: maybe move to "filesystem_test.py" or something t = self.create_target() other_path = t.path + "-" + str(random.randint(0, 999999999)) t._touchz() fs = t.fs self.assertTrue(t.exists()) fs.move(t.path, other_path) self.assertFalse(t.exists()) def test_rename_dont_move_on_fs(self): # We're cheating and retrieving the fs from target. # TODO: maybe move to "filesystem_test.py" or something t = self.create_target() other_path = t.path + "-" + str(random.randint(0, 999999999)) t._touchz() fs = t.fs self.assertTrue(t.exists()) fs.rename_dont_move(t.path, other_path) self.assertFalse(t.exists()) self.assertRaises(luigi.target.FileAlreadyExists, lambda: fs.rename_dont_move(t.path, other_path)) class TemporaryPathTest(unittest.TestCase): def setUp(self): super(TemporaryPathTest, self).setUp() self.fs = Mock() class MyFileSystemTarget(luigi.target.FileSystemTarget): open = None # Must be implemented due to abc stuff fs = self.fs self.target_cls = MyFileSystemTarget def test_temporary_path_files(self): target_outer = self.target_cls("/tmp/notreal.xls") target_inner = self.target_cls("/tmp/blah.txt") class MyException(Exception): pass orig_ex = MyException() try: with target_outer.temporary_path() as tmp_path_outer: self.assertIn("notreal", tmp_path_outer) with target_inner.temporary_path() as tmp_path_inner: self.assertIn("blah", tmp_path_inner) with target_inner.temporary_path() as tmp_path_inner_2: self.assertNotEqual(tmp_path_inner, tmp_path_inner_2) self.fs.rename_dont_move.assert_called_once_with(tmp_path_inner_2, target_inner.path) self.fs.rename_dont_move.assert_called_with(tmp_path_inner, target_inner.path) self.assertEqual(self.fs.rename_dont_move.call_count, 2) raise orig_ex except MyException as ex: self.assertIs(ex, orig_ex) else: assert False self.assertEqual(self.fs.rename_dont_move.call_count, 2) def test_temporary_path_directory(self): target_slash = self.target_cls("/tmp/dir/") target_noslash = self.target_cls("/tmp/dir") with 
target_slash.temporary_path() as tmp_path: assert re.match(r"/tmp/dir-luigi-tmp-\d{10}/", tmp_path) self.fs.rename_dont_move.assert_called_once_with(tmp_path, target_slash.path) with target_noslash.temporary_path() as tmp_path: assert re.match(r"/tmp/dir-luigi-tmp-\d{10}", tmp_path) self.fs.rename_dont_move.assert_called_with(tmp_path, target_noslash.path) def test_windowsish_dir(self): target = self.target_cls(r"""C:\my\folder""" + "\\") pattern = r"""C:\\my\\folder-luigi-tmp-\d{10}""" + r"\\" with target.temporary_path() as tmp_path: assert re.match(pattern, tmp_path) self.fs.rename_dont_move.assert_called_once_with(tmp_path, target.path) def test_hadoopish_dir(self): target = self.target_cls(r"""hdfs:///user/arash/myfile.uids""") with target.temporary_path() as tmp_path: assert re.match(r"""hdfs:///user/arash/myfile.uids-luigi-tmp-\d{10}""", tmp_path) self.fs.rename_dont_move.assert_called_once_with(tmp_path, target.path) def test_creates_dir_for_file(self): target = self.target_cls("/my/file/is/awesome.txt") with target.temporary_path(): self.fs.mkdir.assert_called_once_with("/my/file/is", parents=True, raise_if_exists=False) def test_creates_dir_for_dir(self): target = self.target_cls("/my/dir/is/awesome/") with target.temporary_path(): self.fs.mkdir.assert_called_once_with("/my/dir/is", parents=True, raise_if_exists=False) def test_file_in_current_dir(self): target = self.target_cls("foo.txt") with target.temporary_path() as tmp_path: self.fs.mkdir.assert_not_called() # there is no dir to create self.fs.rename_dont_move.assert_called_once_with(tmp_path, target.path) ================================================ FILE: test/task_bulk_complete_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2016 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from helpers import unittest from luigi import Parameter, Task from luigi.task import MixinNaiveBulkComplete COMPLETE_TASKS = ["A", "B", "C"] class MockTask(MixinNaiveBulkComplete, Task): param_a = Parameter() param_b = Parameter(default="Not Mandatory") def complete(self): return self.param_a in COMPLETE_TASKS class MixinNaiveBulkCompleteTest(unittest.TestCase): """ Test that the MixinNaiveBulkComplete can handle input as - iterable of parameters (for single param tasks) - iterable of parameter tuples (for multi param tasks) - iterable of parameter dicts (for multi param tasks) """ def test_single_arg_list(self): single_arg_list = ["A", "B", "x"] expected_single_arg_list = {p for p in single_arg_list if p in COMPLETE_TASKS} self.assertEqual(expected_single_arg_list, set(MockTask.bulk_complete(single_arg_list))) def test_multiple_arg_tuple(self): multiple_arg_tuple = (("A", "1"), ("B", "2"), ("X", "3"), ("C", "2")) expected_multiple_arg_tuple = {p for p in multiple_arg_tuple if p[0] in COMPLETE_TASKS} self.assertEqual(expected_multiple_arg_tuple, set(MockTask.bulk_complete(multiple_arg_tuple))) def test_multiple_arg_dict(self): multiple_arg_dict = ({"param_a": "X", "param_b": "1"}, {"param_a": "C", "param_b": "1"}) expected_multiple_arg_dict = [p for p in multiple_arg_dict if p["param_a"] in COMPLETE_TASKS] self.assertEqual(expected_multiple_arg_dict, MockTask.bulk_complete(multiple_arg_dict)) ================================================ FILE: test/task_forwarded_attributes_test.py ================================================ # -*- coding: utf-8 -*- # # 
Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from helpers import LuigiTestCase, RunOnceTask import luigi import luigi.scheduler import luigi.worker FORWARDED_ATTRIBUTES = set(luigi.worker.TaskProcess.forward_reporter_attributes.values()) class NonYieldingTask(RunOnceTask): # need to accept messages in order for the "scheduler_message" attribute to be not None accepts_messages = True def gather_forwarded_attributes(self): """ Returns a set of names of attributes that are forwarded by the TaskProcess and that are not *None*. The tests in this file check if and which attributes are present at different times, e.g. while running, or before and after a dynamic dependency was yielded. 
""" attrs = set() for attr in FORWARDED_ATTRIBUTES: if getattr(self, attr, None) is not None: attrs.add(attr) return attrs def run(self): # store names of forwarded attributes which are only available within the run method self.attributes_while_running = self.gather_forwarded_attributes() # invoke the run method of the RunOnceTask which marks this task as complete RunOnceTask.run(self) class YieldingTask(NonYieldingTask): def run(self): # as TaskProcess._run_get_new_deps handles generators in a specific way, store names of # forwarded attributes before and after yielding a dynamic dependency, so we can explicitly # validate the attribute forwarding implementation self.attributes_before_yield = self.gather_forwarded_attributes() yield RunOnceTask() self.attributes_after_yield = self.gather_forwarded_attributes() # invoke the run method of the RunOnceTask which marks this task as complete RunOnceTask.run(self) class TaskForwardedAttributesTest(LuigiTestCase): def run_task(self, task): sch = luigi.scheduler.Scheduler() with luigi.worker.Worker(scheduler=sch) as w: w.add(task) w.run() return task def test_non_yielding_task(self): task = self.run_task(NonYieldingTask()) self.assertEqual(task.attributes_while_running, FORWARDED_ATTRIBUTES) def test_yielding_task(self): task = self.run_task(YieldingTask()) self.assertEqual(task.attributes_before_yield, FORWARDED_ATTRIBUTES) self.assertEqual(task.attributes_after_yield, FORWARDED_ATTRIBUTES) ================================================ FILE: test/task_history_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from helpers import LuigiTestCase import luigi import luigi.scheduler import luigi.task_history import luigi.worker luigi.notifications.DEBUG = True class SimpleTaskHistory(luigi.task_history.TaskHistory): def __init__(self): self.actions = [] def task_scheduled(self, task): self.actions.append(("scheduled", task.id)) def task_finished(self, task, successful): self.actions.append(("finished", task.id)) def task_started(self, task, worker_host): self.actions.append(("started", task.id)) class TaskHistoryTest(LuigiTestCase): def test_run(self): th = SimpleTaskHistory() sch = luigi.scheduler.Scheduler(task_history_impl=th) with luigi.worker.Worker(scheduler=sch) as w: class MyTask(luigi.Task): pass task = MyTask() w.add(task) w.run() self.assertEqual(th.actions, [("scheduled", task.task_id), ("started", task.task_id), ("finished", task.task_id)]) ================================================ FILE: test/task_progress_percentage_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # from helpers import LuigiTestCase import luigi import luigi.scheduler import luigi.worker class TaskProgressPercentageTest(LuigiTestCase): def test_run(self): sch = luigi.scheduler.Scheduler() with luigi.worker.Worker(scheduler=sch) as w: class MyTask(luigi.Task): def run(self): self.set_progress_percentage(30) task = MyTask() w.add(task) w.run() self.assertEqual(sch.get_task_progress_percentage(task.task_id)["progressPercentage"], 30) ================================================ FILE: test/task_register_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2017 VNG Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from helpers import LuigiTestCase import luigi from luigi.task_register import ( Register, TaskClassAmbigiousException, TaskClassNotFoundException, ) class TaskRegisterTest(LuigiTestCase): def test_externalize_taskclass(self): with self.assertRaises(TaskClassNotFoundException): Register.get_task_cls("scooby.Doo") class Task1(luigi.Task): @classmethod def get_task_family(cls): return "scooby.Doo" self.assertEqual(Task1, Register.get_task_cls("scooby.Doo")) class Task2(luigi.Task): @classmethod def get_task_family(cls): return "scooby.Doo" with self.assertRaises(TaskClassAmbigiousException): Register.get_task_cls("scooby.Doo") class Task3(luigi.Task): @classmethod def get_task_family(cls): return "scooby.Doo" # There previously was a rare bug where the third installed class could # "undo" class ambiguity. with self.assertRaises(TaskClassAmbigiousException): Register.get_task_cls("scooby.Doo") ================================================ FILE: test/task_running_resources_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import multiprocessing import os import signal import time from contextlib import contextmanager from helpers import RunOnceTask, skipOnGithubActions, unittest, with_config import luigi import luigi.server class ResourceTestTask(RunOnceTask): param = luigi.Parameter() reduce_foo = luigi.BoolParameter() def process_resources(self): return {"foo": 2} def run(self): if self.reduce_foo: self.decrease_running_resources({"foo": 1}) time.sleep(2) super(ResourceTestTask, self).run() class ResourceWrapperTask(RunOnceTask): reduce_foo = ResourceTestTask.reduce_foo def requires(self): return [ ResourceTestTask(param="a", reduce_foo=self.reduce_foo), ResourceTestTask(param="b"), ] class LocalRunningResourcesTest(unittest.TestCase): def test_resource_reduction(self): # trivial resource reduction on local scheduler # test the running_task_resources setter and getter sch = luigi.scheduler.Scheduler(resources={"foo": 2}) with luigi.worker.Worker(scheduler=sch) as w: task = ResourceTestTask(param="a", reduce_foo=True) w.add(task) w.run() self.assertEqual(sch.get_running_task_resources(task.task_id)["resources"]["foo"], 1) class ConcurrentRunningResourcesTest(unittest.TestCase): @with_config({"scheduler": {"stable_done_cooldown_secs": "0"}}) def setUp(self): super(ConcurrentRunningResourcesTest, self).setUp() # run the luigi server in a new process and wait for its startup self._process = multiprocessing.Process(target=luigi.server.run) self._process.start() time.sleep(0.5) # configure the rpc scheduler, update the foo resource self.sch = luigi.rpc.RemoteScheduler() self.sch.update_resource("foo", 3) def tearDown(self): super(ConcurrentRunningResourcesTest, self).tearDown() # graceful server shutdown self._process.terminate() self._process.join(timeout=1) if self._process.is_alive(): os.kill(self._process.pid, signal.SIGKILL) @contextmanager def worker(self, scheduler=None, processes=2): with luigi.worker.Worker(scheduler=scheduler or self.sch, worker_processes=processes) as w: 
w._config.wait_interval = 0.2 w._config.check_unfulfilled_deps = False yield w @contextmanager def assert_duration(self, min_duration=0, max_duration=-1): t0 = time.time() try: yield finally: duration = time.time() - t0 self.assertGreater(duration, min_duration) if max_duration > 0: self.assertLess(duration, max_duration) def test_tasks_serial(self): # serial test # run two tasks that do not reduce the "foo" resource # as the total foo resource (3) is smaller than the requirement of two tasks (4), # the scheduler is forced to run them serially which takes longer than 4 seconds with self.worker() as w: w.add(ResourceWrapperTask(reduce_foo=False)) with self.assert_duration(min_duration=4): w.run() @skipOnGithubActions("Temporary skipping on GH actions") # TODO: Fix and remove skip def test_tasks_parallel(self): # parallel test # run two tasks and the first one lowers its requirement on the "foo" resource, so that # the total "foo" resource (3) is sufficient to run both tasks in parallel shortly after # the first task started, so the entire process should not exceed 4 seconds with self.worker() as w: w.add(ResourceWrapperTask(reduce_foo=True)) with self.assert_duration(max_duration=4): w.run() ================================================ FILE: test/task_serialize_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# """ We want to test that task_id is consistent when generated from: 1. A real task instance 2. The task_family and a dictionary of parameter values (as strings) 3. A json representation of #2 We use the hypothesis package to do property-based tests. """ import json import string from datetime import datetime import hypothesis as hyp from hypothesis.strategies import datetimes as hyp_datetimes import luigi _no_value = luigi.parameter._no_value def _mk_param_strategy(param_cls, param_value_strat, with_default=None): if with_default is None: default = hyp.strategies.one_of(hyp.strategies.just(_no_value), param_value_strat) elif with_default: default = param_value_strat else: default = hyp.strategies.just(_no_value) return hyp.strategies.builds(param_cls, description=hyp.strategies.text(alphabet=string.printable), default=default) def _mk_task(name, params): return type(name, (luigi.Task,), params) # identifiers must be str not unicode in Python2 identifiers = hyp.strategies.builds(str, hyp.strategies.text(alphabet=string.ascii_letters, min_size=1, max_size=16)) text = hyp.strategies.text(alphabet=string.printable) # Luigi parameters with a default parameters_def = _mk_param_strategy(luigi.Parameter, text, True) int_parameters_def = _mk_param_strategy(luigi.IntParameter, hyp.strategies.integers(), True) float_parameters_def = _mk_param_strategy(luigi.FloatParameter, hyp.strategies.floats(min_value=-1e100, max_value=+1e100), True) bool_parameters_def = _mk_param_strategy(luigi.BoolParameter, hyp.strategies.booleans(), True) date_parameters_def = _mk_param_strategy(luigi.DateParameter, hyp_datetimes(min_value=datetime(1900, 1, 1)), True) any_default_parameters = hyp.strategies.one_of(parameters_def, int_parameters_def, float_parameters_def, bool_parameters_def, date_parameters_def) # Tasks with up to 3 random parameters tasks_with_defaults = hyp.strategies.builds(_mk_task, name=identifiers, params=hyp.strategies.dictionaries(identifiers, any_default_parameters, 
max_size=3)) def _task_to_dict(task): # Generate the parameter value dictionary. Use each parameter's serialize() method param_dict = {} for key, param in task.get_params(): param_dict[key] = param.serialize(getattr(task, key)) return param_dict def _task_from_dict(task_cls, param_dict): # Regenerate the task from the dictionary task_params = {} for key, param in task_cls.get_params(): task_params[key] = param.parse(param_dict[key]) return task_cls(**task_params) @hyp.given(tasks_with_defaults) def test_serializable(task_cls): task = task_cls() param_dict = _task_to_dict(task) task2 = _task_from_dict(task_cls, param_dict) assert task.task_id == task2.task_id @hyp.given(tasks_with_defaults) def test_json_serializable(task_cls): task = task_cls() param_dict = _task_to_dict(task) param_dict = json.loads(json.dumps(param_dict)) task2 = _task_from_dict(task_cls, param_dict) assert task.task_id == task2.task_id @hyp.given(tasks_with_defaults) def test_task_id_alphanumeric(task_cls): task = task_cls() task_id = task.task_id valid = string.ascii_letters + string.digits + "_" assert [x for x in task_id if x not in valid] == [] # TODO : significant an non-significant parameters ================================================ FILE: test/task_status_message_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# from helpers import LuigiTestCase import luigi import luigi.scheduler import luigi.worker luigi.notifications.DEBUG = True class TaskStatusMessageTest(LuigiTestCase): def test_run(self): message = "test message" sch = luigi.scheduler.Scheduler() with luigi.worker.Worker(scheduler=sch) as w: class MyTask(luigi.Task): def run(self): self.set_status_message(message) task = MyTask() w.add(task) w.run() self.assertEqual(sch.get_task_status_message(task.task_id)["statusMessage"], message) ================================================ FILE: test/task_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#
import collections
import doctest
import pickle
import warnings
from datetime import datetime, timedelta

from helpers import LuigiTestCase, unittest, with_config

import luigi
import luigi.parameter
import luigi.task
import luigi.util
from luigi.task_register import load_task


class DummyTask(luigi.Task):
    """Task with one parameter of each common type, used for round-trip and type-warning tests."""

    param = luigi.Parameter()
    bool_param = luigi.BoolParameter()
    int_param = luigi.IntParameter()
    float_param = luigi.FloatParameter()
    date_param = luigi.DateParameter()
    datehour_param = luigi.DateHourParameter()
    timedelta_param = luigi.TimeDeltaParameter()
    insignificant_param = luigi.Parameter(significant=False)


DUMMY_TASK_OK_PARAMS = dict(
    param="test",
    bool_param=True,
    int_param=666,
    float_param=123.456,
    date_param=datetime(2014, 9, 13).date(),
    datehour_param=datetime(2014, 9, 13, 9),
    timedelta_param=timedelta(44),  # doesn't support seconds
    insignificant_param="test",
)


class DefaultInsignificantParamTask(luigi.Task):
    insignificant_param = luigi.Parameter(significant=False, default="value")
    necessary_param = luigi.Parameter(significant=False)


class TaskTest(unittest.TestCase):
    def test_tasks_doctest(self):
        # testmod() only reports failures; it never raises, so check the failure count explicitly
        results = doctest.testmod(luigi.task)
        self.assertEqual(results.failed, 0, "doctests in luigi.task failed")

    def test_task_to_str_to_task(self):
        original = DummyTask(**DUMMY_TASK_OK_PARAMS)
        other = DummyTask.from_str_params(original.to_str_params())
        self.assertEqual(original, other)

    def test_task_from_str_insignificant(self):
        params = {"necessary_param": "needed"}
        original = DefaultInsignificantParamTask(**params)
        other = DefaultInsignificantParamTask.from_str_params(params)
        self.assertEqual(original, other)

    def test_task_missing_necessary_param(self):
        with self.assertRaises(luigi.parameter.MissingParameterException):
            DefaultInsignificantParamTask.from_str_params({})

    def test_external_tasks_loadable(self):
        task = load_task("luigi", "ExternalTask", {})
        self.assertIsInstance(task, luigi.ExternalTask)

    def test_getpaths(self):
        class RequiredTask(luigi.Task):
            def output(self):
                return luigi.LocalTarget("/path/to/target/file")

        t = RequiredTask()
        reqs = {}
        reqs["bare"] = t
        reqs["dict"] = {"key": t}
        reqs["OrderedDict"] = collections.OrderedDict([("key", t)])
        reqs["list"] = [t]
        reqs["tuple"] = (t,)
        reqs["generator"] = (t for _ in range(10))

        struct = luigi.task.getpaths(reqs)
        self.assertIsInstance(struct, dict)
        self.assertIsInstance(struct["bare"], luigi.Target)
        self.assertIsInstance(struct["dict"], dict)
        self.assertIsInstance(struct["OrderedDict"], collections.OrderedDict)
        self.assertIsInstance(struct["list"], list)
        self.assertIsInstance(struct["tuple"], tuple)
        self.assertTrue(hasattr(struct["generator"], "__iter__"))

    def test_flatten(self):
        flatten = luigi.task.flatten
        self.assertEqual(sorted(flatten({"a": "foo", "b": "bar"})), ["bar", "foo"])
        self.assertEqual(sorted(flatten(["foo", ["bar", "troll"]])), ["bar", "foo", "troll"])
        self.assertEqual(flatten("foo"), ["foo"])
        self.assertEqual(flatten(42), [42])
        self.assertEqual(flatten((len(i) for i in ["foo", "troll"])), [3, 5])
        self.assertRaises(TypeError, flatten, (len(i) for i in ["foo", "troll", None]))

    def test_externalized_task_picklable(self):
        task = luigi.task.externalize(luigi.Task())
        pickled_task = pickle.dumps(task)
        self.assertEqual(task, pickle.loads(pickled_task))

    def test_no_unpicklable_properties(self):
        task = luigi.Task()
        task.set_tracking_url = lambda tracking_url: tracking_url
        task.set_status_message = lambda message: message
        with task.no_unpicklable_properties():
            pickle.dumps(task)
        # the callables must be restored once the context exits
        self.assertIsNotNone(task.set_tracking_url)
        self.assertIsNotNone(task.set_status_message)
        tracking_url = task.set_tracking_url("http://test.luigi.com/")
        self.assertEqual(tracking_url, "http://test.luigi.com/")
        message = task.set_status_message("message")
        self.assertEqual(message, "message")

    def test_no_warn_if_param_types_ok(self):
        with warnings.catch_warnings(record=True) as w:
            # Record every UserWarning (the category the type-mismatch tests below expect), so
            # ambient warning filters cannot hide one and make this check pass vacuously.
            warnings.simplefilter(action="always", category=UserWarning)
            DummyTask(**DUMMY_TASK_OK_PARAMS)
        self.assertEqual(len(w), 0, msg="No warning should be raised when correct parameter types are used")

    def test_warn_on_non_str_param(self):
        params = dict(**DUMMY_TASK_OK_PARAMS)
        params["param"] = 42
        with self.assertWarnsRegex(UserWarning, 'Parameter "param" with value "42" is not of type string.'):
            DummyTask(**params)

    def test_warn_on_non_timedelta_param(self):
        params = dict(**DUMMY_TASK_OK_PARAMS)

        class MockTimedelta:
            days = 1
            seconds = 1

        params["timedelta_param"] = MockTimedelta()
        with self.assertWarnsRegex(UserWarning, 'Parameter "timedelta_param" with value ".*" is not of type timedelta.'):
            DummyTask(**params)

    def test_disable_window_seconds(self):
        """
        Deprecated disable_window_seconds param uses disable_window value
        """

        class ATask(luigi.Task):
            disable_window = 17

        task = ATask()
        self.assertEqual(task.disable_window_seconds, 17)

    @with_config({"ATaskWithBadParam": {"bad_param": "bad_value"}})
    def test_bad_param(self):
        class ATaskWithBadParam(luigi.Task):
            bad_param = luigi.IntParameter()

        with self.assertRaisesRegex(ValueError, r"ATaskWithBadParam\[args=\(\), kwargs={}\]: Error when parsing the default value of 'bad_param'"):
            ATaskWithBadParam()

    @with_config(
        {
            "TaskA": {
                "a": "a",
                "b": "b",
                "c": "c",
            },
            "TaskB": {
                "a": "a",
                "b": "b",
                "c": "c",
            },
        }
    )
    def test_unconsumed_params(self):
        class TaskA(luigi.Task):
            a = luigi.Parameter(default="a")

        class TaskB(luigi.Task):
            a = luigi.Parameter(default="a")

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings(
                action="ignore",
                category=Warning,
            )
            warnings.simplefilter(
                action="always",
                category=luigi.parameter.UnconsumedParameterWarning,
            )

            TaskA()
            TaskB()

            # unittest assertions instead of bare asserts, which ``python -O`` strips
            self.assertEqual(len(w), 4)
            expected = [
                ("b", "TaskA"),
                ("c", "TaskA"),
                ("b", "TaskB"),
                ("c", "TaskB"),
            ]
            for i, (expected_value, task_name) in zip(w, expected):
                self.assertTrue(issubclass(i.category, luigi.parameter.UnconsumedParameterWarning))
                self.assertEqual(
                    str(i.message),
                    f"The configuration contains the parameter '{expected_value}' with value '{expected_value}' that is not consumed by the task '{task_name}'.",
                )

    @with_config(
        {
            "TaskEdgeCase": {
                "camelParam": "camelCase",
                "underscore_param": "underscore",
                "dash-param": "dash",
            },
        }
    )
    def test_unconsumed_params_edge_cases(self):
        class TaskEdgeCase(luigi.Task):
            camelParam = luigi.Parameter()
            underscore_param = luigi.Parameter()
            dash_param = luigi.Parameter()

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings(
                action="ignore",
                category=Warning,
            )
            warnings.simplefilter(
                action="always",
                category=luigi.parameter.UnconsumedParameterWarning,
            )
            task = TaskEdgeCase()
            self.assertEqual(len(w), 0)
            self.assertEqual(task.camelParam, "camelCase")
            self.assertEqual(task.underscore_param, "underscore")
            # the config key "dash-param" must feed the dash_param attribute
            self.assertEqual(task.dash_param, "dash")

    @with_config(
        {
            "TaskIgnoreUnconsumed": {
                "a": "a",
                "b": "b",
                "c": "c",
            },
        }
    )
    def test_unconsumed_params_ignore_unconsumed(self):
        class TaskIgnoreUnconsumed(luigi.Task):
            ignore_unconsumed = {"b", "d"}

            a = luigi.Parameter()

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings(
                action="ignore",
                category=Warning,
            )
            warnings.simplefilter(
                action="always",
                category=luigi.parameter.UnconsumedParameterWarning,
            )
            TaskIgnoreUnconsumed()
            # only "c" is reported: "a" is consumed and "b" is explicitly ignored
            self.assertEqual(len(w), 1)


class TaskFlattenOutputTest(unittest.TestCase):
    """Tests for luigi.task.flatten_output on plain, wrapper and diamond-shaped dependency graphs."""

    def test_single_task(self):
        expected = [luigi.LocalTarget("f1.txt"), luigi.LocalTarget("f2.txt")]

        class TestTask(luigi.ExternalTask):
            def output(self):
                return expected

        self.assertListEqual(luigi.task.flatten_output(TestTask()), expected)

    def test_wrapper_task(self):
        expected = [luigi.LocalTarget("f1.txt"), luigi.LocalTarget("f2.txt")]

        class Test1Task(luigi.ExternalTask):
            def output(self):
                return expected[0]

        class Test2Task(luigi.ExternalTask):
            def output(self):
                return expected[1]

        @luigi.util.requires(Test1Task, Test2Task)
        class TestWrapperTask(luigi.WrapperTask):
            pass

        self.assertListEqual(luigi.task.flatten_output(TestWrapperTask()), expected)

    def test_wrapper_tasks_diamond(self):
        expected = [luigi.LocalTarget("file.txt")]

        class TestTask(luigi.ExternalTask):
            def output(self):
                return expected

        @luigi.util.requires(TestTask)
        class LeftWrapperTask(luigi.WrapperTask):
            pass

        @luigi.util.requires(TestTask)
        class RightWrapperTask(luigi.WrapperTask):
            pass

        @luigi.util.requires(LeftWrapperTask, RightWrapperTask)
        class MasterWrapperTask(luigi.WrapperTask):
            pass

        # the shared TestTask output must appear only once
        self.assertListEqual(luigi.task.flatten_output(MasterWrapperTask()), expected)


class ExternalizeTaskTest(LuigiTestCase):
    def test_externalize_taskclass(self):
        class MyTask(luigi.Task):
            def run(self):
                pass

        self.assertIsNotNone(MyTask.run)  # Assert what we believe
        task_object = luigi.task.externalize(MyTask)()
        self.assertIsNone(task_object.run)
        self.assertIsNotNone(MyTask.run)  # Check immutability
        self.assertIsNotNone(MyTask().run)  # Check immutability

    def test_externalize_taskobject(self):
        class MyTask(luigi.Task):
            def run(self):
                pass

        task_object = luigi.task.externalize(MyTask())
        self.assertIsNone(task_object.run)
        self.assertIsNotNone(MyTask.run)  # Check immutability
        self.assertIsNotNone(MyTask().run)  # Check immutability

    def test_externalize_taskclass_readable_name(self):
        class MyTask(luigi.Task):
            def run(self):
                pass

        task_class = luigi.task.externalize(MyTask)
        self.assertIsNot(task_class, MyTask)
        self.assertIn("MyTask", task_class.__name__)

    def test_externalize_taskclass_instance_cache(self):
        class MyTask(luigi.Task):
            def run(self):
                pass

        task_class = luigi.task.externalize(MyTask)
        self.assertIsNot(task_class, MyTask)
        self.assertIs(MyTask(), MyTask())  # Assert it have enabled the instance caching
        self.assertIsNot(task_class(), MyTask())  # Now, they should not be the same of course

    def test_externalize_same_id(self):
        class MyTask(luigi.Task):
            def run(self):
                pass

        task_normal = MyTask()
        task_ext_1 = luigi.task.externalize(MyTask)()
        task_ext_2 = luigi.task.externalize(MyTask())
        self.assertEqual(task_normal.task_id, task_ext_1.task_id)
        self.assertEqual(task_normal.task_id, task_ext_2.task_id)

    def test_externalize_same_id_with_task_namespace(self):
        # Dependent on the new behavior from spotify/luigi#1953
        class MyTask(luigi.Task):
            task_namespace = "something.domething"

            def run(self):
                pass

        task_normal = MyTask()
        task_ext_1 = luigi.task.externalize(MyTask())
        task_ext_2 = luigi.task.externalize(MyTask)()
        self.assertEqual(task_normal.task_id, task_ext_1.task_id)
        self.assertEqual(task_normal.task_id, task_ext_2.task_id)
        self.assertEqual(str(task_normal), str(task_ext_1))
        self.assertEqual(str(task_normal), str(task_ext_2))

    def test_externalize_same_id_with_luigi_namespace(self):
        # Dependent on the new behavior from spotify/luigi#1953
        luigi.namespace("lets.externalize")

        class MyTask(luigi.Task):
            def run(self):
                pass

        luigi.namespace()
        task_normal = MyTask()
        task_ext_1 = luigi.task.externalize(MyTask())
        task_ext_2 = luigi.task.externalize(MyTask)()
        self.assertEqual(task_normal.task_id, task_ext_1.task_id)
        self.assertEqual(task_normal.task_id, task_ext_2.task_id)
        self.assertEqual(str(task_normal), str(task_ext_1))
        self.assertEqual(str(task_normal), str(task_ext_2))

    def test_externalize_with_requires(self):
        class MyTask(luigi.Task):
            def run(self):
                pass

        @luigi.util.requires(luigi.task.externalize(MyTask))
        class Requirer(luigi.Task):
            def run(self):
                pass

        self.assertIsNotNone(MyTask.run)  # Check immutability
        self.assertIsNotNone(MyTask().run)  # Check immutability

    def test_externalize_doesnt_affect_the_registry(self):
        class MyTask(luigi.Task):
            pass

        reg_orig = luigi.task_register.Register._get_reg()
        luigi.task.externalize(MyTask)
        reg_afterwards = luigi.task_register.Register._get_reg()
        self.assertEqual(reg_orig, reg_afterwards)

    def test_can_uniquely_command_line_parse(self):
        class MyTask(luigi.Task):
            pass

        # This first check is just an assumption rather than assertion
        self.assertTrue(self.run_locally(["MyTask"]))
        luigi.task.externalize(MyTask)
        # Now we check we don't encounter "ambiguous task" issues
        self.assertTrue(self.run_locally(["MyTask"]))
        # We do this once again, is there previously was a bug like this.
luigi.task.externalize(MyTask) self.assertTrue(self.run_locally(["MyTask"])) class TaskNamespaceTest(LuigiTestCase): def setup_tasks(self): class Foo(luigi.Task): pass class FooSubclass(Foo): pass return (Foo, FooSubclass, self.go_mynamespace()) def go_mynamespace(self): luigi.namespace("mynamespace") class Foo(luigi.Task): p = luigi.IntParameter() class Bar(Foo): task_namespace = "othernamespace" # namespace override class Baz(Bar): # inherits namespace for Bar pass luigi.namespace() return collections.namedtuple("mynamespace", "Foo Bar Baz")(Foo, Bar, Baz) def test_vanilla(self): (Foo, FooSubclass, namespace_test_helper) = self.setup_tasks() self.assertEqual(Foo.task_family, "Foo") self.assertEqual(str(Foo()), "Foo()") self.assertEqual(FooSubclass.task_family, "FooSubclass") self.assertEqual(str(FooSubclass()), "FooSubclass()") def test_namespace(self): (Foo, FooSubclass, namespace_test_helper) = self.setup_tasks() self.assertEqual(namespace_test_helper.Foo.task_family, "mynamespace.Foo") self.assertEqual(str(namespace_test_helper.Foo(1)), "mynamespace.Foo(p=1)") self.assertEqual(namespace_test_helper.Bar.task_namespace, "othernamespace") self.assertEqual(namespace_test_helper.Bar.task_family, "othernamespace.Bar") self.assertEqual(str(namespace_test_helper.Bar(1)), "othernamespace.Bar(p=1)") self.assertEqual(namespace_test_helper.Baz.task_namespace, "othernamespace") self.assertEqual(namespace_test_helper.Baz.task_family, "othernamespace.Baz") self.assertEqual(str(namespace_test_helper.Baz(1)), "othernamespace.Baz(p=1)") def test_uses_latest_namespace(self): luigi.namespace("a") class _BaseTask(luigi.Task): pass luigi.namespace("b") class _ChildTask(_BaseTask): pass luigi.namespace() # Reset everything child_task = _ChildTask() self.assertEqual(child_task.task_family, "b._ChildTask") self.assertEqual(str(child_task), "b._ChildTask()") def test_with_scope(self): luigi.namespace("wohoo", scope="task_test") luigi.namespace("bleh", scope="") class 
MyTask(luigi.Task): pass luigi.namespace(scope="task_test") luigi.namespace(scope="") self.assertEqual(MyTask.get_task_namespace(), "wohoo") def test_with_scope_not_matching(self): luigi.namespace("wohoo", scope="incorrect_namespace") luigi.namespace("bleh", scope="") class MyTask(luigi.Task): pass luigi.namespace(scope="incorrect_namespace") luigi.namespace(scope="") self.assertEqual(MyTask.get_task_namespace(), "bleh") class AutoNamespaceTest(LuigiTestCase): this_module = "task_test" def test_auto_namespace_global(self): luigi.auto_namespace() class MyTask(luigi.Task): pass luigi.namespace() self.assertEqual(MyTask.get_task_namespace(), self.this_module) def test_auto_namespace_scope(self): luigi.auto_namespace(scope="task_test") luigi.namespace("bleh", scope="") class MyTask(luigi.Task): pass luigi.namespace(scope="task_test") luigi.namespace(scope="") self.assertEqual(MyTask.get_task_namespace(), self.this_module) def test_auto_namespace_not_matching(self): luigi.auto_namespace(scope="incorrect_namespace") luigi.namespace("bleh", scope="") class MyTask(luigi.Task): pass luigi.namespace(scope="incorrect_namespace") luigi.namespace(scope="") self.assertEqual(MyTask.get_task_namespace(), "bleh") def test_auto_namespace_not_matching_2(self): luigi.auto_namespace(scope="incorrect_namespace") class MyTask(luigi.Task): pass luigi.namespace(scope="incorrect_namespace") self.assertEqual(MyTask.get_task_namespace(), "") class InitSubclassTest(LuigiTestCase): def test_task_works_with_init_subclass(self): class ReceivesClassKwargs(luigi.Task): def __init_subclass__(cls, x, **kwargs): super(ReceivesClassKwargs, cls).__init_subclass__() cls.x = x class Receiver(ReceivesClassKwargs, x=1): pass self.assertEqual(Receiver.x, 1) ================================================ FILE: test/test_sigpipe.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # 
you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import os from helpers import unittest from luigi.format import InputPipeProcessWrapper BASH_SCRIPT = """ #!/bin/bash trap "touch /tmp/luigi_sigpipe.marker; exit 141" SIGPIPE for i in {1..3} do sleep 0.1 echo "Welcome $i times" done """ FAIL_SCRIPT = ( BASH_SCRIPT + """ exit 1 """ ) class TestSigpipe(unittest.TestCase): def setUp(self): with open("/tmp/luigi_test_sigpipe.sh", "w") as fp: fp.write(BASH_SCRIPT) def tearDown(self): os.remove("/tmp/luigi_test_sigpipe.sh") if os.path.exists("/tmp/luigi_sigpipe.marker"): os.remove("/tmp/luigi_sigpipe.marker") def test_partial_read(self): p1 = InputPipeProcessWrapper(["bash", "/tmp/luigi_test_sigpipe.sh"]) self.assertEqual(p1.readline().decode("utf8"), "Welcome 1 times\n") p1.close() self.assertTrue(os.path.exists("/tmp/luigi_sigpipe.marker")) def test_full_read(self): p1 = InputPipeProcessWrapper(["bash", "/tmp/luigi_test_sigpipe.sh"]) counter = 1 for line in p1: self.assertEqual(line.decode("utf8"), "Welcome %i times\n" % counter) counter += 1 p1.close() self.assertFalse(os.path.exists("/tmp/luigi_sigpipe.marker")) class TestSubprocessException(unittest.TestCase): def setUp(self): with open("/tmp/luigi_test_sigpipe.sh", "w") as fp: fp.write(FAIL_SCRIPT) def tearDown(self): os.remove("/tmp/luigi_test_sigpipe.sh") if os.path.exists("/tmp/luigi_sigpipe.marker"): os.remove("/tmp/luigi_sigpipe.marker") def test_partial_read(self): p1 = InputPipeProcessWrapper(["bash", "/tmp/luigi_test_sigpipe.sh"]) self.assertEqual(p1.readline().decode("utf8"), "Welcome 1 
times\n") p1.close() self.assertTrue(os.path.exists("/tmp/luigi_sigpipe.marker")) def test_full_read(self): def run(): p1 = InputPipeProcessWrapper(["bash", "/tmp/luigi_test_sigpipe.sh"]) counter = 1 for line in p1: self.assertEqual(line.decode("utf8"), "Welcome %i times\n" % counter) counter += 1 p1.close() self.assertRaises(RuntimeError, run) ================================================ FILE: test/test_ssh.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import subprocess from helpers import unittest from luigi.contrib.ssh import RemoteContext class TestMockedRemoteContext(unittest.TestCase): def test_subprocess_delegation(self): """Test subprocess call structure using mock module""" orig_Popen = subprocess.Popen self.last_test = None def Popen(cmd, **kwargs): self.last_test = cmd subprocess.Popen = Popen context = RemoteContext("some_host", username="luigi", key_file="/some/key.pub") context.Popen(["ls"]) self.assertTrue("ssh" in self.last_test) self.assertTrue("-i" in self.last_test) self.assertTrue("/some/key.pub" in self.last_test) self.assertTrue("luigi@some_host" in self.last_test) self.assertTrue("ls" in self.last_test) subprocess.Popen = orig_Popen def test_check_output_fail_connect(self): """Test check_output to a non-existing host""" context = RemoteContext("__NO_HOST_LIKE_THIS__", connect_timeout=1) self.assertRaises(subprocess.CalledProcessError, context.check_output, ["ls"]) ================================================ FILE: test/testconfig/core-site.xml ================================================ <?xml version="1.0"?> <?xml-stylesheet type="text/xsl" href="configuration.xsl"?> <configuration> <property> <name>fs.defaultFS</name> <value>hdfs://localhost:50030/</value> </property> </configuration> ================================================ FILE: test/testconfig/log4j.properties ================================================ hadoop.root.logger=INFO,stderr log4j.logger.org.apache.hadoop=INFO,stderr log4j.logger.org.apache.hadoop.util.NativeCodeLoader=Off log4j.appender.stderr = org.apache.log4j.ConsoleAppender log4j.appender.stderr.layout = org.apache.log4j.PatternLayout log4j.appender.stderr.Target = System.err ================================================ FILE: test/testconfig/logging.cfg ================================================ [loggers] keys=root [handlers] keys=consoleHandler [formatters] keys=simpleFormatter [logger_root] level=DEBUG handlers=consoleHandler 
[handler_consoleHandler] class=StreamHandler level=DEBUG formatter=simpleFormatter args=(sys.stdout,) [formatter_simpleFormatter] format=%(levelname)s: %(message)s ================================================ FILE: test/testconfig/luigi.toml ================================================ [core] logging_conf_file = "test/testconfig/logging.cfg" [hdfs] client = "hadoopcli" snakebite_autoconfig = false namenode_host = "must be overridden in local config" [SomeTask] param = {key1 = "value1", key2 = "value2"} ================================================ FILE: test/testconfig/luigi_local.toml ================================================ [hdfs] namenode_host = "localhost" namenode_port = 50030 ================================================ FILE: test/testconfig/luigi_logging.toml ================================================ [logging] version = 1 disable_existing_loggers = false [logging.formatters.mockformatter] format = "{levelname}: {message}" style = "{" [logging.handlers.mockhandler] class = "logging.StreamHandler" level = "INFO" formatter = "mockformatter" [logging.loggers.mocklogger] handlers = ["mockhandler"] level = 'INFO' disabled = false propagate = false ================================================ FILE: test/testconfig/pyproject.toml ================================================ [tool.mypy] plugins = ["luigi.mypy"] ignore_missing_imports = true ================================================ FILE: test/util_previous_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import datetime from helpers import unittest import luigi import luigi.date_interval from luigi.util import get_previous_completed, previous class DateTaskOk(luigi.Task): date = luigi.DateParameter() def complete(self): # test against 2000.03.01 return self.date in [datetime.date(2000, 2, 25), datetime.date(2000, 3, 1), datetime.date(2000, 3, 2)] class DateTaskOkTest(unittest.TestCase): def test_previous(self): task = DateTaskOk(datetime.date(2000, 3, 1)) prev = previous(task) self.assertEqual(prev.date, datetime.date(2000, 2, 29)) def test_get_previous_completed(self): task = DateTaskOk(datetime.date(2000, 3, 1)) prev = get_previous_completed(task, 5) self.assertEqual(prev.date, datetime.date(2000, 2, 25)) def test_get_previous_completed_not_found(self): task = DateTaskOk(datetime.date(2000, 3, 1)) prev = get_previous_completed(task, 4) self.assertEqual(None, prev) class DateHourTaskOk(luigi.Task): hour = luigi.DateHourParameter() def complete(self): # test against 2000.03.01T02 return self.hour in [datetime.datetime(2000, 2, 29, 22), datetime.datetime(2000, 3, 1, 2), datetime.datetime(2000, 3, 1, 3)] class DateHourTaskOkTest(unittest.TestCase): def test_previous(self): task = DateHourTaskOk(datetime.datetime(2000, 3, 1, 2)) prev = previous(task) self.assertEqual(prev.hour, datetime.datetime(2000, 3, 1, 1)) def test_get_previous_completed(self): task = DateHourTaskOk(datetime.datetime(2000, 3, 1, 2)) prev = get_previous_completed(task, 4) self.assertEqual(prev.hour, datetime.datetime(2000, 2, 29, 22)) def test_get_previous_completed_not_found(self): task = 
DateHourTaskOk(datetime.datetime(2000, 3, 1, 2)) prev = get_previous_completed(task, 3) self.assertEqual(None, prev) class DateMinuteTaskOk(luigi.Task): minute = luigi.DateMinuteParameter() def complete(self): # test against 2000.03.01T02H03 return self.minute in [datetime.datetime(2000, 3, 1, 2, 0)] class DateMinuteTaskOkTest(unittest.TestCase): def test_previous(self): task = DateMinuteTaskOk(datetime.datetime(2000, 3, 1, 2, 3)) prev = previous(task) self.assertEqual(prev.minute, datetime.datetime(2000, 3, 1, 2, 2)) def test_get_previous_completed(self): task = DateMinuteTaskOk(datetime.datetime(2000, 3, 1, 2, 3)) prev = get_previous_completed(task, 3) self.assertEqual(prev.minute, datetime.datetime(2000, 3, 1, 2, 0)) def test_get_previous_completed_not_found(self): task = DateMinuteTaskOk(datetime.datetime(2000, 3, 1, 2, 3)) prev = get_previous_completed(task, 2) self.assertEqual(None, prev) class DateSecondTaskOk(luigi.Task): second = luigi.DateSecondParameter() def complete(self): return self.second in [datetime.datetime(2000, 3, 1, 2, 3, 4)] class DateSecondTaskOkTest(unittest.TestCase): def test_previous(self): task = DateSecondTaskOk(datetime.datetime(2000, 3, 1, 2, 3, 7)) prev = previous(task) self.assertEqual(prev.second, datetime.datetime(2000, 3, 1, 2, 3, 6)) def test_get_previous_completed(self): task = DateSecondTaskOk(datetime.datetime(2000, 3, 1, 2, 3, 7)) prev = get_previous_completed(task, 3) self.assertEqual(prev.second, datetime.datetime(2000, 3, 1, 2, 3, 4)) def test_get_previous_completed_not_found(self): task = DateSecondTaskOk(datetime.datetime(2000, 3, 1, 2, 3)) prev = get_previous_completed(task, 2) self.assertEqual(None, prev) class DateIntervalTaskOk(luigi.Task): interval = luigi.DateIntervalParameter() def complete(self): return self.interval in [luigi.date_interval.Week(1999, 48), luigi.date_interval.Week(2000, 1), luigi.date_interval.Week(2000, 2)] class DateIntervalTaskOkTest(unittest.TestCase): def test_previous(self): task = 
DateIntervalTaskOk(luigi.date_interval.Week(2000, 1)) prev = previous(task) self.assertEqual(prev.interval, luigi.date_interval.Week(1999, 52)) def test_get_previous_completed(self): task = DateIntervalTaskOk(luigi.date_interval.Week(2000, 1)) prev = get_previous_completed(task, 5) self.assertEqual(prev.interval, luigi.date_interval.Week(1999, 48)) def test_get_previous_completed_not_found(self): task = DateIntervalTaskOk(luigi.date_interval.Week(2000, 1)) prev = get_previous_completed(task, 4) self.assertEqual(None, prev) class ExtendedDateTaskOk(DateTaskOk): param1 = luigi.Parameter() param2 = luigi.IntParameter(default=2) class ExtendedDateTaskOkTest(unittest.TestCase): def test_previous(self): task = ExtendedDateTaskOk(datetime.date(2000, 3, 1), "some value") prev = previous(task) self.assertEqual(prev.date, datetime.date(2000, 2, 29)) self.assertEqual(prev.param1, "some value") self.assertEqual(prev.param2, 2) class MultiTemporalTaskNok(luigi.Task): date = luigi.DateParameter() hour = luigi.DateHourParameter() class MultiTemporalTaskNokTest(unittest.TestCase): def test_previous(self): task = MultiTemporalTaskNok(datetime.date(2000, 1, 1), datetime.datetime(2000, 1, 1, 1)) self.assertRaises(NotImplementedError, previous, task) self.assertRaises(NotImplementedError, get_previous_completed, task) class NoTemporalTaskNok(luigi.Task): param = luigi.Parameter() class NoTemporalTaskNokTest(unittest.TestCase): def test_previous(self): task = NoTemporalTaskNok("some value") self.assertRaises(NotImplementedError, previous, task) self.assertRaises(NotImplementedError, get_previous_completed, task) ================================================ FILE: test/util_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2016 VNG Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from helpers import LuigiTestCase, RunOnceTask

import luigi
import luigi.task
from luigi.util import inherits, requires


class BasicsTest(LuigiTestCase):
    """Tests for the ``luigi.util.inherits`` and ``luigi.util.requires`` class decorators."""

    # following tests using inherits decorator
    def test_task_ids_using_inherits(self):
        class ParentTask(luigi.Task):
            my_param = luigi.Parameter()

        # ChildTask is declared while the "blah" namespace is active, so its
        # string form is expected to carry the "blah." prefix.
        luigi.namespace("blah")

        @inherits(ParentTask)
        class ChildTask(luigi.Task):
            def requires(self):
                return self.clone(ParentTask)

        # Reset the namespace so later task declarations are unaffected.
        luigi.namespace("")
        # ChildTask accepts ParentTask's my_param thanks to @inherits.
        child_task = ChildTask(my_param="hello")
        self.assertEqual(str(child_task), "blah.ChildTask(my_param=hello)")
        self.assertIn(ParentTask(my_param="hello"), luigi.task.flatten(child_task.requires()))

    def test_task_ids_using_inherits_2(self):
        # Here we use this decorator in an unusual way: the decorator is
        # built first and applied by hand after the namespace is reset.
        # But it should still work.
        class ParentTask(luigi.Task):
            my_param = luigi.Parameter()

        decorator = inherits(ParentTask)
        luigi.namespace("blah")

        class ChildTask(luigi.Task):
            def requires(self):
                return self.clone_parent()

        luigi.namespace("")
        # Applied after the namespace reset; the assertion below expects the
        # "blah." prefix captured at class-definition time to survive.
        ChildTask = decorator(ChildTask)
        child_task = ChildTask(my_param="hello")
        self.assertEqual(str(child_task), "blah.ChildTask(my_param=hello)")
        self.assertIn(ParentTask(my_param="hello"), luigi.task.flatten(child_task.requires()))

    def test_task_ids_using_inherits_kwargs(self):
        # Same as test_task_ids_using_inherits, but passing the parent by keyword.
        class ParentTask(luigi.Task):
            my_param = luigi.Parameter()

        luigi.namespace("blah")

        @inherits(parent=ParentTask)
        class ChildTask(luigi.Task):
            def requires(self):
                return self.clone(ParentTask)

        luigi.namespace("")
        child_task = ChildTask(my_param="hello")
        self.assertEqual(str(child_task), "blah.ChildTask(my_param=hello)")
        self.assertIn(ParentTask(my_param="hello"), luigi.task.flatten(child_task.requires()))

    def _setup_parent_and_child_inherits(self):
        # ParentTask stores its parameter on the class when it runs and only
        # reports complete once that stored value is "actuallyset", so a test
        # can observe whether it was run (directly or via ChildTask).
        class ParentTask(luigi.Task):
            my_parameter = luigi.Parameter()
            class_variable = "notset"

            def run(self):
                self.__class__.class_variable = self.my_parameter

            def complete(self):
                return self.class_variable == "actuallyset"

        @inherits(ParentTask)
        class ChildTask(RunOnceTask):
            def requires(self):
                return self.clone_parent()

        return ParentTask

    def test_inherits_has_effect_run_child(self):
        ParentTask = self._setup_parent_and_child_inherits()
        # Running the child must pull in (and run) the parent with the same parameter.
        self.assertTrue(self.run_locally_split("ChildTask --my-parameter actuallyset"))
        self.assertEqual(ParentTask.class_variable, "actuallyset")

    def test_inherits_has_effect_run_parent(self):
        ParentTask = self._setup_parent_and_child_inherits()
        self.assertTrue(self.run_locally_split("ParentTask --my-parameter actuallyset"))
        self.assertEqual(ParentTask.class_variable, "actuallyset")

    def _setup_inherits_inheritence(self):
        # ChildTask both subclasses ParentTask and is decorated with @inherits(InheritedTask).
        class InheritedTask(luigi.Task):
            pass

        class ParentTask(luigi.Task):
            pass

        @inherits(InheritedTask)
        class ChildTask(ParentTask):
            pass

        return ChildTask

    def test_inherits_has_effect_MRO(self):
        ChildTask = self._setup_inherits_inheritence()
        # The first two MRO entries of the decorated class must render differently.
        self.assertNotEqual(str(ChildTask.__mro__[0]), str(ChildTask.__mro__[1]))

    # following tests using requires decorator
    def test_task_ids_using_requries(self):
        class ParentTask(luigi.Task):
            my_param = luigi.Parameter()

        luigi.namespace("blah")

        # @requires supplies requires() itself, so ChildTask needs no body.
        @requires(ParentTask)
        class ChildTask(luigi.Task):
            pass

        luigi.namespace("")
        child_task = ChildTask(my_param="hello")
        self.assertEqual(str(child_task), "blah.ChildTask(my_param=hello)")
        self.assertIn(ParentTask(my_param="hello"), luigi.task.flatten(child_task.requires()))

    def test_task_ids_using_requries_2(self):
        # Here we use this decorator in an unusual way: the decorator is
        # built first and applied by hand after the namespace is reset.
        # But it should still work.
        class ParentTask(luigi.Task):
            my_param = luigi.Parameter()

        decorator = requires(ParentTask)
        luigi.namespace("blah")

        class ChildTask(luigi.Task):
            pass

        luigi.namespace("")
        ChildTask = decorator(ChildTask)
        child_task = ChildTask(my_param="hello")
        self.assertEqual(str(child_task), "blah.ChildTask(my_param=hello)")
        self.assertIn(ParentTask(my_param="hello"), luigi.task.flatten(child_task.requires()))

    def _setup_parent_and_child(self):
        # Same observable ParentTask as in _setup_parent_and_child_inherits,
        # but ChildTask is wired up with @requires instead of @inherits.
        class ParentTask(luigi.Task):
            my_parameter = luigi.Parameter()
            class_variable = "notset"

            def run(self):
                self.__class__.class_variable = self.my_parameter

            def complete(self):
                return self.class_variable == "actuallyset"

        @requires(ParentTask)
        class ChildTask(RunOnceTask):
            pass

        return ParentTask

    def test_requires_has_effect_run_child(self):
        ParentTask = self._setup_parent_and_child()
        self.assertTrue(self.run_locally_split("ChildTask --my-parameter actuallyset"))
        self.assertEqual(ParentTask.class_variable, "actuallyset")

    def test_requires_has_effect_run_parent(self):
        ParentTask = self._setup_parent_and_child()
        self.assertTrue(self.run_locally_split("ParentTask --my-parameter actuallyset"))
        self.assertEqual(ParentTask.class_variable, "actuallyset")

    def _setup_requires_inheritence(self):
        # ChildTask both subclasses ParentTask and is decorated with @requires(RequiredTask).
        class RequiredTask(luigi.Task):
            pass

        class ParentTask(luigi.Task):
            pass

        @requires(RequiredTask)
        class ChildTask(ParentTask):
            pass

        return ChildTask

    def test_requires_has_effect_MRO(self):
        ChildTask = self._setup_requires_inheritence()
        self.assertNotEqual(str(ChildTask.__mro__[0]), str(ChildTask.__mro__[1]))

    def test_kwargs_requires_gives_named_inputs(self):
        # With keyword arguments, @requires should make input() a dict keyed by those names.
        class ParentTask(RunOnceTask):
            def output(self):
                return "Target"

        @requires(parent_1=ParentTask, parent_2=ParentTask)
        class ChildTask(RunOnceTask):
            resulting_input = "notset"

            def run(self):
                # Stash input() on the class so the test can inspect it after the run.
                self.__class__.resulting_input = self.input()

        self.assertTrue(self.run_locally_split("ChildTask"))
        self.assertEqual(ChildTask.resulting_input, {"parent_1": "Target", "parent_2": "Target"})


================================================
FILE: test/visible_parameters_test.py
================================================
import json

from helpers import unittest

import luigi
from luigi.parameter import ParameterVisibility


# One HIDDEN, one default-visibility and one PRIVATE parameter, all significant.
class TestTask1(luigi.Task):
    param_one = luigi.Parameter(default="1", visibility=ParameterVisibility.HIDDEN, significant=True)
    param_two = luigi.Parameter(default="2", significant=True)
    param_three = luigi.Parameter(default="3", visibility=ParameterVisibility.PRIVATE, significant=True)


# Every parameter PRIVATE.
class TestTask2(luigi.Task):
    param_one = luigi.Parameter(default="1", visibility=ParameterVisibility.PRIVATE)
    param_two = luigi.Parameter(default="2", visibility=ParameterVisibility.PRIVATE)
    param_three = luigi.Parameter(default="3", visibility=ParameterVisibility.PRIVATE)


# Every parameter HIDDEN; param_two is not significant.
class TestTask3(luigi.Task):
    param_one = luigi.Parameter(default="1", visibility=ParameterVisibility.HIDDEN, significant=True)
    param_two = luigi.Parameter(default="2", visibility=ParameterVisibility.HIDDEN, significant=False)
    param_three = luigi.Parameter(default="3", visibility=ParameterVisibility.HIDDEN, significant=True)


# Same as TestTask3 but every parameter PUBLIC.
class TestTask4(luigi.Task):
    param_one = luigi.Parameter(default="1", visibility=ParameterVisibility.PUBLIC, significant=True)
    param_two = luigi.Parameter(default="2", visibility=ParameterVisibility.PUBLIC, significant=False)
    param_three = luigi.Parameter(default="3", visibility=ParameterVisibility.PUBLIC, significant=True)


class Test(unittest.TestCase):
    """Tests for how parameter visibility affects ``to_str_params`` and related helpers."""

    def test_to_str_params(self):
        # PRIVATE parameters are dropped; HIDDEN and PUBLIC ones are kept.
        task = TestTask1()
        self.assertEqual(task.to_str_params(), {"param_one": "1", "param_two": "2"})

        task = TestTask2()
        self.assertEqual(task.to_str_params(), {})

        task = TestTask3()
        self.assertEqual(task.to_str_params(), {"param_one": "1", "param_two": "2", "param_three": "3"})

    def test_all_public_equals_all_hidden(self):
        # HIDDEN vs PUBLIC makes no difference to the default to_str_params() output.
        hidden = TestTask3()
        public = TestTask4()

        self.assertEqual(public.to_str_params(), hidden.to_str_params())

    def test_all_public_equals_all_hidden_using_significant(self):
        hidden = TestTask3()
        public = TestTask4()

        self.assertEqual(public.to_str_params(only_significant=True), hidden.to_str_params(only_significant=True))

    def test_private_params_and_significant(self):
        task = TestTask1()

        self.assertEqual(task.to_str_params(), task.to_str_params(only_significant=True))

    def test_param_visibilities(self):
        # PRIVATE params are absent; values are serialized visibilities
        # (HIDDEN -> 1, PUBLIC -> 0, see test_json_dumps).
        task = TestTask1()

        self.assertEqual(task._get_param_visibilities(), {"param_one": 1, "param_two": 0})

    def test_incorrect_visibility_value(self):
        # An out-of-range visibility (5) is reported as PUBLIC (0).
        class Task(luigi.Task):
            a = luigi.Parameter(default="val", visibility=5)

        task = Task()

        self.assertEqual(task._get_param_visibilities(), {"a": 0})

    def test_task_id_exclude_hidden_and_private_params(self):
        # only_public=True drops both HIDDEN and PRIVATE parameters.
        task = TestTask1()

        self.assertEqual({"param_two": "2"}, task.to_str_params(only_public=True))

    def test_json_dumps(self):
        # Visibilities serialize to plain ints that round-trip through JSON.
        public = json.dumps(ParameterVisibility.PUBLIC.serialize())
        hidden = json.dumps(ParameterVisibility.HIDDEN.serialize())
        private = json.dumps(ParameterVisibility.PRIVATE.serialize())

        self.assertEqual("0", public)
        self.assertEqual("1", hidden)
        self.assertEqual("2", private)

        public = json.loads(public)
        hidden = json.loads(hidden)
        private = json.loads(private)

        self.assertEqual(0, public)
        self.assertEqual(1, hidden)
        self.assertEqual(2, private)


================================================
FILE: test/visualiser/__init__.py
================================================ # Tests for visualiser javascript. ================================================ FILE: test/visualiser/phantomjs_test.js ================================================ var page = require('webpage').create(); var system = require('system'); var tests = []; /* * Parse command line to get Luigi scheduler URL */ if (system.args.length === 1) { console.log('Usage: phantom_test.js <scheduler-url>'); phantom.exit(); } var url = system.args[1]; /* * Minimal test framework */ function do_tests(page) { var ok = true; var retval; tests.forEach(function (spec) { var name = spec[0]; var test_func = spec[1]; retval = report(page.evaluate(test_func), name); ok = ok && retval; }); return ok; } function report(retval, func_name) { if (retval === true) { console.log('[ OK ] ' + func_name); return true; } else { console.log('[FAIL] ' + func_name); console.log(retval); return false; } } phantom.onError = function(msg, trace) { var msgStack = ['PHANTOM ERROR: ' + msg]; if (trace && trace.length) { msgStack.push('TRACE:'); trace.forEach(function(t) { msgStack.push(' -> ' + (t.file || t.sourceURL) + ': ' + t.line + (t.function ? ' (in function ' + t.function +')' : '')); }); } console.error(msgStack.join('\n')); phantom.exit(1); }; page.onError = function(msg, trace) { var msgStack = ['ERROR: ' + msg]; if (trace && trace.length) { msgStack.push('TRACE:'); trace.forEach(function(t) { msgStack.push(' -> ' + t.file + ': ' + t.line + (t.function ? ' (in function "' + t.function +'")' : '')); }); } console.error(msgStack.join('\n')); }; /** * def_test: define a test * @param test_name: Name of test * @param func: A function which will be evaluated within the page and should return * true for success or any other value for failure. 
*/ function def_test(test_name, func) { tests.push([test_name, func]); } /* * Test definitions */ def_test('failed_info_test', function () { var el = $('#FAILED_info .info-box-number')[0]; if (el.textContent === "4") { return true; } else { return el.textContent; } }); def_test('done_info_test', function () { var el = $('#DONE_info .info-box-number')[0]; if (el.textContent === "68") { return true; } else { return el.textContent; } }); def_test('upstream_failure_info_test', function () { var el = $('#UPSTREAM_FAILED_info .info-box-number')[0]; if (el.textContent === '45') { return true; } else { return el.textContent; } }); def_test('result_count_test', function () { var el = $('#taskTable_info')[0]; if (el.textContent.match(/Showing \d+ to \d+ of 117 entries/)) { return true; } else { return el.textContent; } }); def_test('filtered_result_count_test1', function () { var ret; var target = $('ul.sidebar-menu li a').first(); target.click(); var el = $('#taskTable_info')[0]; if (el.textContent.match(/Showing \d+ to \d+ of 29 entries.*from 117 total entries/)) { ret = true; } else { ret = el.textContent; } target.click(); return ret; }); def_test('filtered_result_count_test2', function () { var ret; var target = $('#FAILED_info').first(); target.click(); var el = $('#taskTable_info')[0]; if (el.textContent.match(/Showing \d+ to \d+ of 4 entries.*from 117 total entries/)) { ret = true; } else { ret = el.textContent; } target.click(); return ret; }); def_test('filtered_result_count_test3', function () { var ret; var target = $('#PENDING_info').first(); target.click(); var el = $('#taskTable_info')[0]; if (el.textContent.match(/Showing \d+ to \d+ of 0 entries.*from 117 total entries/)) { ret = true; } else { ret = el.textContent; } target.click(); return ret; }); def_test('filtered_result_count_test4', function () { var ret; var target = $('#RUNNING_info').first(); target.click(); var el = $('#taskTable_info')[0]; if (el.textContent.match(/Showing \d+ to \d+ of 0 
entries.*from 117 total entries/)) { ret = true; } else { ret = el.textContent; } target.click(); return ret; }); def_test('filtered_result_count_test5', function () { var ret; var target = $('#DONE_info').first(); target.click(); var el = $('#taskTable_info')[0]; if (el.textContent.match(/Showing \d+ to \d+ of 68 entries.*from 117 total entries/)) { ret = true; } else { ret = el.textContent; } target.click(); return ret; }); def_test('filtered_result_count_test5', function () { var ret; var target = $('#DISABLED_info').first(); target.click(); var el = $('#taskTable_info')[0]; if (el.textContent.match(/Showing \d+ to \d+ of 0 entries.*from 117 total entries/)) { ret = true; } else { ret = el.textContent; } target.click(); return ret; }); def_test('filtered_result_count_test5', function () { var ret; var target = $('#UPSTREAM_DISABLED_info').first(); target.click(); var el = $('#taskTable_info')[0]; if (el.textContent.match(/Showing \d+ to \d+ of 0 entries.*from 117 total entries/)) { ret = true; } else { ret = el.textContent; } target.click(); return ret; }); def_test('filtered_result_count_test5', function () { var ret; var target = $('#UPSTREAM_FAILED_info').first(); target.click(); var el = $('#taskTable_info')[0]; if (el.textContent.match(/Showing \d+ to \d+ of 45 entries.*from 117 total entries/)) { ret = true; } else { ret = el.textContent; } target.click(); return ret; }); def_test('searched_result_count_test1', function () { var ret; var dt = $('#taskTable').DataTable(); dt.search('FailingMergeSort_1').draw(); var el = $('#taskTable_info')[0]; if (el.textContent.match(/Showing \d+ to \d+ of 29 entries.*from 117 total entries/)) { ret = true; } else { ret = el.textContent; } dt.search('').draw(); return ret; }); def_test('searched_result_count_test1', function () { var ret; var target = $('#serverSide label').first(); var dt = $('#taskTable').DataTable(); target.click(); dt.search('FailingMergeSort_1').draw(); var el = $('#taskTable_info')[0]; if 
(el.textContent.match(/Showing \d+ to \d+ of 29 entries.*from 117 total entries/)) { ret = true; } else { ret = el.textContent; } target.click(); dt.search('').draw(); return ret; }); page.open(url, function(status) { var ok; console.log("Loaded " + url + ", status: " + status); if(status === "success") { ok = do_tests(page); } console.log('RESULT: ' + ok); phantom.exit(ok === true ? 0 : -1); }); ================================================ FILE: test/visualiser/visualiser_test.py ================================================ """ Test the visualiser's javascript using PhantomJS. """ import os import subprocess import sys import threading import time import unittest from selenium import webdriver import luigi here = os.path.dirname(__file__) # Patch-up path so that we can import from the directory above this one.r # This seems to be necessary because the `test` directory has no __init__.py but # adding one makes other tests fail. sys.path.append(os.path.join(here, "..")) from server_test import ServerTestBase # noqa TEST_TIMEOUT = 40 @unittest.skipUnless(os.environ.get("TEST_VISUALISER"), "PhantomJS tests not requested in TEST_VISUALISER") class TestVisualiser(ServerTestBase): """ Builds a medium-sized task tree of MergeSort results then starts phantomjs as a subprocess to interact with the scheduler. """ def setUp(self): super(TestVisualiser, self).setUp() x = "I scream for ice cream" task = UberTask(base_task=FailingMergeSort, x=x, copies=4) luigi.build([task], workers=1, scheduler_port=self.get_http_port()) self.done = threading.Event() def _do_ioloop(): # Enter ioloop for maximum TEST_TIMEOUT. Check every 2s whether the test has finished. 
print("Entering event loop in separate thread") for i in range(TEST_TIMEOUT): try: self.wait(timeout=1) except AssertionError: pass if self.done.is_set(): break print("Exiting event loop thread") self.iothread = threading.Thread(target=_do_ioloop) self.iothread.start() def tearDown(self): self.done.set() self.iothread.join() def test(self): port = self.get_http_port() print("Server port is {}".format(port)) print("Starting phantomjs") p = subprocess.Popen("phantomjs {}/phantomjs_test.js http://localhost:{}".format(here, port), shell=True, stdin=None) # PhantomJS may hang on an error so poll status = None for x in range(TEST_TIMEOUT): status = p.poll() if status is not None: break time.sleep(1) if status is None: raise AssertionError("PhantomJS failed to complete") else: print("PhantomJS return status is {}".format(status)) assert status == 0 # tasks tab tests. def test_keeps_entries_after_page_refresh(self): port = self.get_http_port() driver = webdriver.PhantomJS() driver.get("http://localhost:{}".format(port)) length_select = driver.find_element_by_css_selector('select[name="taskTable_length"]') assert length_select.get_attribute("value") == "10" assert len(driver.find_elements_by_css_selector("#taskTable tbody tr")) == 10 # Now change entries select box and check again. clicked = False for option in length_select.find_elements_by_css_selector("option"): if option.text == "50": option.click() clicked = True break assert clicked, 'Could not click option with "50" entries.' assert length_select.get_attribute("value") == "50" assert len(driver.find_elements_by_css_selector("#taskTable tbody tr")) == 50 # Now refresh page and check. Select box should be 50 and table should contain 50 rows. driver.refresh() # Once page refreshed we have to find all selectors again. 
length_select = driver.find_element_by_css_selector('select[name="taskTable_length"]') assert length_select.get_attribute("value") == "50" assert len(driver.find_elements_by_css_selector("#taskTable tbody tr")) == 50 def test_keeps_table_filter_after_page_refresh(self): port = self.get_http_port() driver = webdriver.PhantomJS() driver.get("http://localhost:{}".format(port)) # Check initial state. search_input = driver.find_element_by_css_selector('input[type="search"]') assert search_input.get_attribute("value") == "" assert len(driver.find_elements_by_css_selector("#taskTable tbody tr")) == 10 # Now filter and check filtered table. search_input.send_keys("ber") # UberTask only should be displayed. assert len(driver.find_elements_by_css_selector("#taskTable tbody tr")) == 1 # Now refresh page and check. Filter input should contain 'ber' and table should contain # one row (UberTask). driver.refresh() # Once page refreshed we have to find all selectors again. search_input = driver.find_element_by_css_selector('input[type="search"]') assert search_input.get_attribute("value") == "ber" assert len(driver.find_elements_by_css_selector("#taskTable tbody tr")) == 1 def test_keeps_order_after_page_refresh(self): port = self.get_http_port() driver = webdriver.PhantomJS() driver.get("http://localhost:{}".format(port)) # Order by name (asc). column = driver.find_elements_by_css_selector("#taskTable thead th")[1] column.click() table_body = driver.find_element_by_css_selector("#taskTable tbody") assert self._get_cell_value(table_body, 0, 1) == "FailingMergeSort_0" # Ordery by name (desc). column.click() assert self._get_cell_value(table_body, 0, 1) == "UberTask" # Now refresh page and check. Table should be ordered by name (desc). driver.refresh() # Once page refreshed we have to find all selectors again. 
table_body = driver.find_element_by_css_selector("#taskTable tbody") assert self._get_cell_value(table_body, 0, 1) == "UberTask" def test_keeps_filter_on_server_after_page_refresh(self): port = self.get_http_port() driver = webdriver.PhantomJS() driver.get("http://localhost:{}/static/visualiser/index.html#tab=tasks".format(port)) # Check initial state. checkbox = driver.find_element_by_css_selector("#serverSideCheckbox") assert checkbox.is_selected() is False # Change invert checkbox. checkbox.click() # Now refresh page and check. Invert checkbox shoud be checked. driver.refresh() # Once page refreshed we have to find all selectors again. checkbox = driver.find_element_by_css_selector("#serverSideCheckbox") assert checkbox.is_selected() def test_synchronizes_fields_on_tasks_tab(self): # Check fields population if tasks tab was opened by direct link port = self.get_http_port() driver = webdriver.PhantomJS() url = "http://localhost:{}/static/visualiser/index.html#tab=tasks&length=50&search__search=er&filterOnServer=1&order=1,desc".format(port) driver.get(url) length_select = driver.find_element_by_css_selector('select[name="taskTable_length"]') assert length_select.get_attribute("value") == "50" search_input = driver.find_element_by_css_selector('input[type="search"]') assert search_input.get_attribute("value") == "er" assert len(driver.find_elements_by_css_selector("#taskTable tbody tr")) == 50 # Table is ordered by first column (name) table_body = driver.find_element_by_css_selector("#taskTable tbody") assert self._get_cell_value(table_body, 0, 1) == "UberTask" # graph tab tests. def test_keeps_invert_after_page_refresh(self): port = self.get_http_port() driver = webdriver.PhantomJS() driver.get("http://localhost:{}/static/visualiser/index.html#tab=graph".format(port)) # Check initial state. invert_checkbox = driver.find_element_by_css_selector("#invertCheckbox") assert invert_checkbox.is_selected() is False # Change invert checkbox. 
invert_checkbox.click() # Now refresh page and check. Invert checkbox shoud be checked. driver.refresh() # Once page refreshed we have to find all selectors again. invert_checkbox = driver.find_element_by_css_selector("#invertCheckbox") assert invert_checkbox.is_selected() def test_keeps_task_id_after_page_refresh(self): port = self.get_http_port() driver = webdriver.PhantomJS() driver.get("http://localhost:{}/static/visualiser/index.html#tab=graph".format(port)) # Check initial state. task_id_input = driver.find_element_by_css_selector("#js-task-id") assert task_id_input.get_attribute("value") == "" # Change task id task_id_input.send_keys("1") driver.find_element_by_css_selector("#loadTaskForm button[type=submit]").click() # Now refresh page and check. Task ID field should contain 1 driver.refresh() # Once page refreshed we have to find all selectors again. task_id_input = driver.find_element_by_css_selector("#js-task-id") assert task_id_input.get_attribute("value") == "1" def test_keeps_hide_done_after_page_refresh(self): port = self.get_http_port() driver = webdriver.PhantomJS() driver.get("http://localhost:{}/static/visualiser/index.html#tab=graph".format(port)) # Check initial state. hide_done_checkbox = driver.find_element_by_css_selector("#hideDoneCheckbox") assert hide_done_checkbox.is_selected() is False # Change invert checkbox. hide_done_checkbox.click() # Now refresh page and check. Invert checkbox shoud be checked. driver.refresh() # Once page refreshed we have to find all selectors again. hide_done_checkbox = driver.find_element_by_css_selector("#hideDoneCheckbox") assert hide_done_checkbox.is_selected() def test_keeps_visualisation_type_after_page_refresh(self): port = self.get_http_port() driver = webdriver.PhantomJS() driver.get("http://localhost:{}/static/visualiser/index.html#tab=graph".format(port)) # Check initial state. 
svg_radio = driver.find_element_by_css_selector("input[value=svg]") assert svg_radio.is_selected() # Change vistype to d3 by clicking on its label. d3_radio = driver.find_element_by_css_selector("input[value=d3]") d3_radio.find_element_by_xpath("..").click() # Now refresh page and check. D3 checkbox shoud be checked. driver.refresh() # Once page refreshed we have to find all selectors again. d3_radio = driver.find_element_by_css_selector("input[value=d3]") assert d3_radio.is_selected() def test_synchronizes_fields_on_graph_tab(self): # Check fields population if tasks tab was opened by direct link. port = self.get_http_port() driver = webdriver.PhantomJS() url = "http://localhost:{}/static/visualiser/index.html#tab=graph&taskId=1&invert=1&hideDone=1&visType=svg".format(port) driver.get(url) # Check task id input task_id_input = driver.find_element_by_css_selector("#js-task-id") assert task_id_input.get_attribute("value") == "1" # Check Show Upstream Dependencies checkbox. invert_checkbox = driver.find_element_by_css_selector("#invertCheckbox") assert invert_checkbox.is_selected() # Check Hide Done checkbox. hide_done_checkbox = driver.find_element_by_css_selector("#hideDoneCheckbox") assert hide_done_checkbox.is_selected() svg_radio = driver.find_element_by_css_selector("input[value=svg]") assert svg_radio.get_attribute("checked") def _get_cell_value(self, elem, row, column): tr = elem.find_elements_by_css_selector("#taskTable tbody tr")[row] td = tr.find_elements_by_css_selector("td")[column] return td.text # --------------------------------------------------------------------------- # Code for generating a tree of tasks with some failures. def generate_task_families(task_class, n): """ Generate n copies of a task with different task_family names. 
:param task_class: a subclass of `luigi.Task` :param n: number of copies of `task_class` to create :return: Dictionary of task_family => task_class """ ret = {} for i in range(n): class_name = "{}_{}".format(task_class.task_family, i) ret[class_name] = type(class_name, (task_class,), {}) return ret class UberTask(luigi.Task): """ A task which depends on n copies of a configurable subclass. """ _done = False base_task = luigi.TaskParameter() x = luigi.Parameter() copies = luigi.IntParameter() def requires(self): task_families = generate_task_families(self.base_task, self.copies) for class_name in task_families: yield task_families[class_name](x=self.x) def complete(self): return self._done def run(self): self._done = True def popmin(a, b): """ popmin(a, b) -> (i, a', b') where i is min(a[0], b[0]) and a'/b' are the results of removing i from the relevant sequence. """ if len(a) == 0: return b[0], a, b[1:] elif len(b) == 0: return a[0], a[1:], b elif a[0] > b[0]: return b[0], a, b[1:] else: return a[0], a[1:], b class MemoryTarget(luigi.Target): def __init__(self): self.box = None def exists(self): return self.box is not None class MergeSort(luigi.Task): x = luigi.Parameter(description="A string to be sorted") def __init__(self, *args, **kwargs): super(MergeSort, self).__init__(*args, **kwargs) self.result = MemoryTarget() def requires(self): # Allows us to override behaviour in subclasses cls = self.__class__ if len(self.x) > 1: p = len(self.x) // 2 return [cls(self.x[:p]), cls(self.x[p:])] def output(self): return self.result def run(self): if len(self.x) > 1: list_1, list_2 = (x.box for x in self.input()) s = [] while list_1 or list_2: item, list_1, list_2 = popmin(list_1, list_2) s.append(item) else: s = self.x self.result.box = "".join(s) class FailingMergeSort(MergeSort): """ Simply fail if the string to sort starts with ' '. 
""" fail_probability = luigi.FloatParameter(default=0.0) def run(self): if self.x[0] == " ": raise Exception("I failed") else: return super(FailingMergeSort, self).run() if __name__ == "__main__": x = "I scream for ice cream" task = UberTask(base_task=FailingMergeSort, x=x, copies=4) luigi.build([task], workers=1, scheduler_port=8082) ================================================ FILE: test/worker_external_task_test.py ================================================ # Copyright (c) 2015 # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of # the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations under # the License. 
import os import shutil import tempfile from helpers import unittest, with_config from mock import patch import luigi import luigi.server import luigi.task import luigi.worker from luigi.local_target import LocalTarget from luigi.scheduler import Scheduler class TestExternalFileTask(luigi.ExternalTask): """Mocking tasks is a pain, so touch a file instead""" path = luigi.Parameter() times_to_call = luigi.IntParameter() def __init__(self, *args, **kwargs): super(TestExternalFileTask, self).__init__(*args, **kwargs) self.times_called = 0 def complete(self): """ Create the file we need after a number of preconfigured attempts """ self.times_called += 1 if self.times_called >= self.times_to_call: open(self.path, "a").close() return os.path.exists(self.path) def output(self): return LocalTarget(path=self.path) class TestTask(luigi.Task): """ Requires a single file dependency """ tempdir = luigi.Parameter() complete_after = luigi.IntParameter() def __init__(self, *args, **kwargs): super(TestTask, self).__init__(*args, **kwargs) self.output_path = os.path.join(self.tempdir, "test.output") self.dep_path = os.path.join(self.tempdir, "test.dep") self.dependency = TestExternalFileTask(path=self.dep_path, times_to_call=self.complete_after) def requires(self): yield self.dependency def output(self): return LocalTarget(path=self.output_path) def run(self): open(self.output_path, "a").close() class WorkerExternalTaskTest(unittest.TestCase): def setUp(self): self.tempdir = tempfile.mkdtemp(prefix="luigi-test-") def tearDown(self): shutil.rmtree(self.tempdir) def _assert_complete(self, tasks): for t in tasks: self.assert_(t.complete()) def _build(self, tasks): with self._make_worker() as w: for t in tasks: w.add(t) w.run() def _make_worker(self): self.scheduler = Scheduler(prune_on_get_work=True) return luigi.worker.Worker(scheduler=self.scheduler, worker_processes=1) def test_external_dependency_already_complete(self): """ Test that the test task completes when its dependency 
exists at the start of the execution. """ test_task = TestTask(tempdir=self.tempdir, complete_after=1) luigi.build([test_task], local_scheduler=True) assert os.path.exists(test_task.dep_path) assert os.path.exists(test_task.output_path) # complete() is called once per failure, twice per success assert test_task.dependency.times_called == 2 @with_config({"worker": {"retry_external_tasks": "true"}, "scheduler": {"retry_delay": "0.0"}}) def test_external_dependency_gets_rechecked(self): """ Test that retry_external_tasks re-checks external tasks """ assert luigi.worker.worker().retry_external_tasks is True test_task = TestTask(tempdir=self.tempdir, complete_after=10) self._build([test_task]) assert os.path.exists(test_task.dep_path) assert os.path.exists(test_task.output_path) self.assertGreaterEqual(test_task.dependency.times_called, 10) @with_config({"worker": {"retry_external_tasks": "true", "keep_alive": "true", "wait_interval": "0.00001"}, "scheduler": {"retry_delay": "0.01"}}) def test_external_dependency_worker_is_patient(self): """ Test that worker doesn't "give up" with keep_alive option Instead, it should sleep for random.uniform() seconds, then ask scheduler for work. """ assert luigi.worker.worker().retry_external_tasks is True with patch("random.uniform", return_value=0.001): test_task = TestTask(tempdir=self.tempdir, complete_after=5) self._build([test_task]) assert os.path.exists(test_task.dep_path) assert os.path.exists(test_task.output_path) self.assertGreaterEqual(test_task.dependency.times_called, 5) def test_external_dependency_bare(self): """ Test ExternalTask without altering global settings. 
""" assert luigi.worker.worker().retry_external_tasks is False test_task = TestTask(tempdir=self.tempdir, complete_after=5) scheduler = luigi.scheduler.Scheduler(retry_delay=0.01, prune_on_get_work=True) with luigi.worker.Worker(retry_external_tasks=True, scheduler=scheduler, keep_alive=True, wait_interval=0.00001, wait_jitter=0) as w: w.add(test_task) w.run() assert os.path.exists(test_task.dep_path) assert os.path.exists(test_task.output_path) self.assertGreaterEqual(test_task.dependency.times_called, 5) @with_config( { "worker": { "retry_external_tasks": "true", }, "scheduler": {"retry_delay": "0.0"}, } ) def test_external_task_complete_but_missing_dep_at_runtime(self): """ Test external task complete but has missing upstream dependency at runtime. Should not get "unfulfilled dependencies" error. """ test_task = TestTask(tempdir=self.tempdir, complete_after=3) test_task.run = NotImplemented assert len(test_task.deps()) > 0 # split up scheduling task and running to simulate runtime scenario with self._make_worker() as w: w.add(test_task) # touch output so test_task should be considered complete at runtime open(test_task.output_path, "a").close() success = w.run() self.assertTrue(success) # upstream dependency output didn't exist at runtime self.assertFalse(os.path.exists(test_task.dep_path)) ================================================ FILE: test/worker_keep_alive_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2016 VNG Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # import threading from helpers import LuigiTestCase import luigi from luigi.scheduler import Scheduler from luigi.worker import Worker class WorkerKeepAliveUpstreamTest(LuigiTestCase): """ Tests related to how the worker stays alive after upstream status changes. See https://github.com/spotify/luigi/pull/1789 """ def run(self, result=None): """ Common setup code. Due to the contextmanager cant use normal setup """ self.sch = Scheduler(retry_delay=0.00000001, retry_count=2) with Worker(scheduler=self.sch, worker_id="X", keep_alive=True, wait_interval=0.1, wait_jitter=0) as w: self.w = w super(WorkerKeepAliveUpstreamTest, self).run(result) def test_alive_while_has_failure(self): """ One dependency disables and one fails """ class Disabler(luigi.Task): pass class Failer(luigi.Task): did_run = False def run(self): self.did_run = True class Wrapper(luigi.WrapperTask): def requires(self): return (Disabler(), Failer()) self.w.add(Wrapper()) disabler = Disabler().task_id failer = Failer().task_id self.sch.add_task(disabler, "FAILED", worker="X") self.sch.prune() # Make scheduler unfail the disabled task self.sch.add_task(disabler, "FAILED", worker="X") # Disable it self.sch.add_task(failer, "FAILED", worker="X") # Fail it try: t = threading.Thread(target=self.w.run) t.start() t.join(timeout=1) # Wait 1 second self.assertTrue(t.is_alive()) # It shouldn't stop trying, the failed task should be retried! self.assertFalse(Failer.did_run) # It should never have run, the cooldown is longer than a second. finally: self.sch.prune() # Make it, like die. Couldn't find a more forceful way to do this. 
t.join(timeout=1) # Wait 1 second assert not t.is_alive() def test_alive_while_has_success(self): """ One dependency disables and one succeeds """ # TODO: Fix copy paste mess class Disabler(luigi.Task): pass class Succeeder(luigi.Task): did_run = False def run(self): self.did_run = True class Wrapper(luigi.WrapperTask): def requires(self): return (Disabler(), Succeeder()) self.w.add(Wrapper()) disabler = Disabler().task_id succeeder = Succeeder().task_id self.sch.add_task(disabler, "FAILED", worker="X") self.sch.prune() # Make scheduler unfail the disabled task self.sch.add_task(disabler, "FAILED", worker="X") # Disable it self.sch.add_task(succeeder, "DONE", worker="X") # Fail it try: t = threading.Thread(target=self.w.run) t.start() t.join(timeout=1) # Wait 1 second self.assertFalse(t.is_alive()) # The worker should think that it should stop ... # ... because in this case the only work remaining depends on DISABLED tasks, # hence it's not worth considering the wrapper task as a PENDING task to # keep the worker alive anymore. self.assertFalse(Succeeder.did_run) # It should never have run, it succeeded already finally: self.sch.prune() # This shouldnt be necessary in this version, but whatevs t.join(timeout=1) # Wait 1 second assert not t.is_alive() ================================================ FILE: test/worker_multiprocess_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # import logging from helpers import unittest from mock import Mock import luigi.notifications import luigi.worker from luigi import Parameter, RemoteScheduler, Task from luigi.worker import Worker luigi.notifications.DEBUG = True class DummyTask(Task): param = Parameter() def __init__(self, *args, **kwargs): super(DummyTask, self).__init__(*args, **kwargs) self.has_run = False def complete(self): old_value = self.has_run self.has_run = True return old_value def run(self): logging.debug("%s - setting has_run", self) self.has_run = True class MultiprocessWorkerTest(unittest.TestCase): def run(self, result=None): self.scheduler = RemoteScheduler() self.scheduler.add_worker = Mock() self.scheduler.add_task = Mock() with Worker(scheduler=self.scheduler, worker_id="X", worker_processes=2) as worker: self.worker = worker super(MultiprocessWorkerTest, self).run(result) def gw_res(self, pending, task_id): return dict(n_pending_tasks=pending, task_id=task_id, running_tasks=0, n_unique_pending=0) def test_positive_path(self): a = DummyTask("a") b = DummyTask("b") class MultipleRequirementTask(DummyTask): def requires(self): return [a, b] c = MultipleRequirementTask("C") self.assertTrue(self.worker.add(c)) self.scheduler.get_work = Mock( side_effect=[self.gw_res(3, a.task_id), self.gw_res(2, b.task_id), self.gw_res(1, c.task_id), self.gw_res(0, None), self.gw_res(0, None)] ) self.assertTrue(self.worker.run()) self.assertTrue(c.has_run) def test_path_with_task_failures(self): class FailingTask(DummyTask): def run(self): raise Exception("I am failing") a = FailingTask("a") b = FailingTask("b") class MultipleRequirementTask(DummyTask): def requires(self): return [a, b] c = MultipleRequirementTask("C") self.assertTrue(self.worker.add(c)) self.scheduler.get_work = Mock( side_effect=[self.gw_res(3, a.task_id), self.gw_res(2, b.task_id), self.gw_res(1, c.task_id), self.gw_res(0, 
None), self.gw_res(0, None)] ) self.assertFalse(self.worker.run()) class SingleWorkerMultiprocessTest(unittest.TestCase): def test_default_multiprocessing_behavior(self): with Worker(worker_processes=1) as worker: task = DummyTask("a") task_process = worker._create_task_process(task) self.assertFalse(task_process.use_multiprocessing) def test_force_multiprocessing(self): with Worker(worker_processes=1, force_multiprocessing=True) as worker: task = DummyTask("a") task_process = worker._create_task_process(task) self.assertTrue(task_process.use_multiprocessing) ================================================ FILE: test/worker_parallel_scheduling_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import contextlib import gc import os import pickle import time import mock import psutil from helpers import unittest import luigi from luigi.task_status import UNKNOWN from luigi.worker import Worker def running_children(): children = set() process = psutil.Process(os.getpid()) for child in process.children(): if child.is_running(): children.add(child.pid) return children @contextlib.contextmanager def pause_gc(): if not gc.isenabled(): yield try: gc.disable() yield finally: gc.enable() class SlowCompleteWrapper(luigi.WrapperTask): def requires(self): return [SlowCompleteTask(i) for i in range(4)] class SlowCompleteTask(luigi.Task): n = luigi.IntParameter() def complete(self): time.sleep(0.1) return True class OverlappingSelfDependenciesTask(luigi.Task): n = luigi.IntParameter() k = luigi.IntParameter() def complete(self): return self.n < self.k or self.k == 0 def requires(self): return [OverlappingSelfDependenciesTask(self.n - 1, k) for k in range(self.k + 1)] class ExceptionCompleteTask(luigi.Task): def complete(self): assert False class ExceptionRequiresTask(luigi.Task): def requires(self): assert False class UnpicklableExceptionTask(luigi.Task): def complete(self): class UnpicklableException(Exception): pass raise UnpicklableException() class ParallelSchedulingTest(unittest.TestCase): def setUp(self): self.sch = mock.Mock() self.w = Worker(scheduler=self.sch, worker_id="x") def added_tasks(self, status): return [kw["task_id"] for args, kw in self.sch.add_task.call_args_list if kw["status"] == status] def test_number_of_processes(self): import multiprocessing real_pool = multiprocessing.Pool(1) with mock.patch("multiprocessing.Pool") as mocked_pool: mocked_pool.return_value = real_pool self.w.add(OverlappingSelfDependenciesTask(n=1, k=1), multiprocess=True, processes=1234) mocked_pool.assert_called_once_with(processes=1234) def test_zero_processes(self): import multiprocessing real_pool = multiprocessing.Pool(1) with mock.patch("multiprocessing.Pool") as 
mocked_pool: mocked_pool.return_value = real_pool self.w.add(OverlappingSelfDependenciesTask(n=1, k=1), multiprocess=True, processes=0) mocked_pool.assert_called_once_with(processes=None) def test_children_terminated(self): before_children = running_children() with pause_gc(): self.w.add( OverlappingSelfDependenciesTask(5, 2), multiprocess=True, ) self.assertLessEqual(running_children(), before_children) def test_multiprocess_scheduling_with_overlapping_dependencies(self): self.w.add(OverlappingSelfDependenciesTask(5, 2), True) self.assertEqual(15, self.sch.add_task.call_count) self.assertEqual( set( ( OverlappingSelfDependenciesTask(n=1, k=1).task_id, OverlappingSelfDependenciesTask(n=2, k=1).task_id, OverlappingSelfDependenciesTask(n=2, k=2).task_id, OverlappingSelfDependenciesTask(n=3, k=1).task_id, OverlappingSelfDependenciesTask(n=3, k=2).task_id, OverlappingSelfDependenciesTask(n=4, k=1).task_id, OverlappingSelfDependenciesTask(n=4, k=2).task_id, OverlappingSelfDependenciesTask(n=5, k=2).task_id, ) ), set(self.added_tasks("PENDING")), ) self.assertEqual( set( ( OverlappingSelfDependenciesTask(n=0, k=0).task_id, OverlappingSelfDependenciesTask(n=0, k=1).task_id, OverlappingSelfDependenciesTask(n=1, k=0).task_id, OverlappingSelfDependenciesTask(n=1, k=2).task_id, OverlappingSelfDependenciesTask(n=2, k=0).task_id, OverlappingSelfDependenciesTask(n=3, k=0).task_id, OverlappingSelfDependenciesTask(n=4, k=0).task_id, ) ), set(self.added_tasks("DONE")), ) @mock.patch("luigi.notifications.send_error_email") def test_raise_exception_in_complete(self, send): self.w.add(ExceptionCompleteTask(), multiprocess=True) send.check_called_once() self.assertEqual(UNKNOWN, self.sch.add_task.call_args[1]["status"]) self.assertFalse(self.sch.add_task.call_args[1]["runnable"]) self.assertTrue("assert False" in send.call_args[0][1]) @mock.patch("luigi.notifications.send_error_email") def test_raise_unpicklable_exception_in_complete(self, send): # verify exception can't be pickled 
self.assertRaises(Exception, UnpicklableExceptionTask().complete) try: UnpicklableExceptionTask().complete() except Exception as e: ex = e self.assertRaises((pickle.PicklingError, AttributeError), pickle.dumps, ex) # verify this can run async self.w.add(UnpicklableExceptionTask(), multiprocess=True) send.check_called_once() self.assertEqual(UNKNOWN, self.sch.add_task.call_args[1]["status"]) self.assertFalse(self.sch.add_task.call_args[1]["runnable"]) self.assertTrue("raise UnpicklableException()" in send.call_args[0][1]) @mock.patch("luigi.notifications.send_error_email") def test_raise_exception_in_requires(self, send): self.w.add(ExceptionRequiresTask(), multiprocess=True) send.check_called_once() self.assertEqual(UNKNOWN, self.sch.add_task.call_args[1]["status"]) self.assertFalse(self.sch.add_task.call_args[1]["runnable"]) ================================================ FILE: test/worker_scheduler_com_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2017 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import contextlib import os import shutil import tempfile import threading import time from helpers import LuigiTestCase import luigi from luigi.scheduler import Scheduler from luigi.worker import Worker class WorkerSchedulerCommunicationTest(LuigiTestCase): """ Tests related to communication between Worker and Scheduler that is based on the ping polling. 
See https://github.com/spotify/luigi/pull/1993 """ def run(self, result=None): self.sch = Scheduler() with Worker(scheduler=self.sch, worker_id="X", ping_interval=1, max_reschedules=0) as w: self.w = w # also save scheduler's worker struct self.sw = self.sch._state.get_worker(self.w._id) super(WorkerSchedulerCommunicationTest, self).run(result) def wrapper_task(test_self): tmp = tempfile.mkdtemp() class MyTask(luigi.Task): n = luigi.IntParameter() delay = 3 def output(self): basename = "%s_%s.txt" % (self.__class__.__name__, self.n) return luigi.LocalTarget(os.path.join(tmp, basename)) def run(self): time.sleep(self.delay) with self.output().open("w") as f: f.write("content\n") class Wrapper(MyTask): delay = 0 def requires(self): return [MyTask(n=n) for n in range(self.n)] return Wrapper, tmp def test_message_handling(self): # add some messages for that worker for i in range(10): self.sw.add_rpc_message("foo", i=i) self.assertEqual(10, len(self.sw.rpc_messages)) self.assertEqual(9, self.sw.rpc_messages[-1]["kwargs"]["i"]) # fetch msgs = self.sw.fetch_rpc_messages() self.assertEqual(0, len(self.sw.rpc_messages)) self.assertEqual(9, msgs[-1]["kwargs"]["i"]) def test_ping_content(self): # add some messages for that worker for i in range(10): self.sw.add_rpc_message("foo", i=i) # ping the scheduler and check the result res = self.sch.ping(worker=self.w._id) self.assertIn("rpc_messages", res) msgs = res["rpc_messages"] self.assertEqual(10, len(msgs)) self.assertEqual("foo", msgs[-1]["name"]) self.assertEqual(9, msgs[-1]["kwargs"]["i"]) # there should be no message left self.assertEqual(0, len(self.sw.rpc_messages)) @contextlib.contextmanager def run_wrapper(self, n): # assign the wrapper task to the worker Wrapper, tmp = self.wrapper_task() wrapper = Wrapper(n=n) self.assertTrue(self.w.add(wrapper)) # check the initial number of worker processes self.assertEqual(1, self.w.worker_processes) # run the task in a thread and while running, increase the number of worker 
processes # via an rpc message t = threading.Thread(target=self.w.run) t.start() # yield yield wrapper, t # finally, check that thread is done self.assertFalse(t.is_alive()) # cleanup the tmp dir shutil.rmtree(tmp) def test_dispatch_valid_message(self): with self.run_wrapper(3) as (wrapper, t): # each of the wrapper task's tasks runs 3 seconds, and the ping/message dispatch # interval is 1 second, so it should be safe to wait 1 second here, add the message # which is then fetched by the keep alive thread and dispatched, so after additional 3 # seconds, the worker will have a changed number of processes t.join(1) self.sch.set_worker_processes(self.w._id, 2) t.join(3) self.assertEqual(2, self.w.worker_processes) # after additional 3 seconds, the wrapper task + all required tasks should be completed t.join(3) self.assertTrue(all(task.complete() for task in wrapper.requires())) self.assertTrue(wrapper.complete()) def test_dispatch_invalid_message(self): # this test is identical to test_dispatch_valid_message, except that the number of processes # is not increased during running as we send an invalid rpc message # in addition, the wrapper will only have two requirements with self.run_wrapper(2) as (wrapper, t): # timing info as above t.join(1) self.sw.add_rpc_message("set_worker_processes_not_there", n=2) t.join(3) self.assertEqual(1, self.w.worker_processes) # after additional 3 seconds, the wrapper task and all required tasks should be completed t.join(3) self.assertTrue(all(task.complete() for task in wrapper.requires())) self.assertTrue(wrapper.complete()) def test_dispatch_unregistered_message(self): # this test is identical to test_dispatch_valid_message, except that the number of processes # is not increased during running as we disable the particular callback to work as a # callback, so we want to achieve sth like # self.w.set_worker_processes.is_rpc_message_callback = False # but this is not possible in py 2 due to wrapped method lookup, see # 
http://stackoverflow.com/questions/9523370/adding-attributes-to-instance-methods-in-python set_worker_processes_orig = self.w.set_worker_processes def set_worker_processes_replacement(*args, **kwargs): return set_worker_processes_orig(*args, **kwargs) self.w.set_worker_processes = set_worker_processes_replacement self.assertFalse(getattr(self.w.set_worker_processes, "is_rpc_message_callback", False)) with self.run_wrapper(2) as (wrapper, t): # timing info as above t.join(1) self.sw.add_rpc_message("set_worker_processes", n=2) t.join(3) self.assertEqual(1, self.w.worker_processes) # after additional 3 seconds, the wrapper task and all required tasks should be completed t.join(3) self.assertTrue(all(task.complete() for task in wrapper.requires())) self.assertTrue(wrapper.complete()) ================================================ FILE: test/worker_task_process_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import multiprocessing from helpers import LuigiTestCase, temporary_unloaded_module import luigi from luigi.worker import Worker class ContextManagedTaskProcessTest(LuigiTestCase): def _test_context_manager(self, force_multiprocessing): CONTEXT_MANAGER_MODULE = b""" class MyContextManager: def __init__(self, task_process): self.task = task_process.task def __enter__(self): assert not self.task.run_event.is_set(), "the task should not have run yet" self.task.enter_event.set() return self def __exit__(self, exc_type=None, exc_value=None, traceback=None): assert self.task.run_event.is_set(), "the task should have run" self.task.exit_event.set() """ class DummyEventRecordingTask(luigi.Task): def __init__(self, *args, **kwargs): self.enter_event = multiprocessing.Event() self.exit_event = multiprocessing.Event() self.run_event = multiprocessing.Event() super(DummyEventRecordingTask, self).__init__(*args, **kwargs) def run(self): assert self.enter_event.is_set(), "the context manager should have been entered" assert not self.exit_event.is_set(), "the context manager should not have been exited yet" assert not self.run_event.is_set(), "the task should not have run yet" self.run_event.set() def complete(self): return self.run_event.is_set() with temporary_unloaded_module(CONTEXT_MANAGER_MODULE) as module_name: t = DummyEventRecordingTask() w = Worker(task_process_context=module_name + ".MyContextManager", force_multiprocessing=force_multiprocessing) w.add(t) self.assertTrue(w.run()) self.assertTrue(t.complete()) self.assertTrue(t.enter_event.is_set()) self.assertTrue(t.exit_event.is_set()) def test_context_manager_without_multiprocessing(self): self._test_context_manager(False) def test_context_manager_with_multiprocessing(self): self._test_context_manager(True) ================================================ FILE: test/worker_task_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the 
Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import multiprocessing import sys from subprocess import check_call from time import sleep import mock from helpers import LuigiTestCase, StringContaining from psutil import Process import luigi import luigi.date_interval import luigi.notifications from luigi.mock import MockTarget from luigi.scheduler import DONE, FAILED from luigi.worker import TaskException, TaskProcess luigi.notifications.DEBUG = True class WorkerTaskTest(LuigiTestCase): def test_constructor(self): class MyTask(luigi.Task): # Test overriding the constructor without calling the superconstructor # This is a simple mistake but caused an error that was very hard to understand def __init__(self): pass def f(): luigi.build([MyTask()], local_scheduler=True) self.assertRaises(TaskException, f) def test_run_none(self): def f(): luigi.build([None], local_scheduler=True) self.assertRaises(TaskException, f) class TaskProcessTest(LuigiTestCase): def test_update_result_queue_on_success(self): # IMO this test makes no sense as it tests internal behavior and have # already broken once during internal non-changing refactoring class SuccessTask(luigi.Task): def on_success(self): return "test success expl" task = SuccessTask() result_queue = multiprocessing.Queue() task_process = TaskProcess(task, 1, result_queue, mock.Mock()) with mock.patch.object(result_queue, "put") as mock_put: task_process.run() mock_put.assert_called_once_with((task.task_id, DONE, "test success expl", [], None)) def 
test_update_result_queue_on_failure(self): # IMO this test makes no sense as it tests internal behavior and have # already broken once during internal non-changing refactoring class FailTask(luigi.Task): def run(self): raise BaseException("Uh oh.") def on_failure(self, exception): return "test failure expl" task = FailTask() result_queue = multiprocessing.Queue() task_process = TaskProcess(task, 1, result_queue, mock.Mock()) with mock.patch.object(result_queue, "put") as mock_put: task_process.run() mock_put.assert_called_once_with((task.task_id, FAILED, "test failure expl", [], [])) def test_fail_on_false_complete(self): class NeverCompleteTask(luigi.Task): def complete(self): return False task = NeverCompleteTask() result_queue = multiprocessing.Queue() task_process = TaskProcess(task, 1, result_queue, mock.Mock(), check_complete_on_run=True) with mock.patch.object(result_queue, "put") as mock_put: task_process.run() mock_put.assert_called_once_with((task.task_id, FAILED, StringContaining("finished running, but complete() is still returning false"), [], None)) def test_fail_on_unfulfilled_dependencies(self): class NeverCompleteTask(luigi.Task): def complete(self): return False class A(NeverCompleteTask): def output(self): return [] class B(NeverCompleteTask): def output(self): return MockTarget("foo-B") class C(NeverCompleteTask): def output(self): return [MockTarget("foo-C1"), MockTarget("foo-C2")] class Main(NeverCompleteTask): def requires(self): return [A(), B(), C()] task = Main() result_queue = multiprocessing.Queue() task_process = TaskProcess(task, 1, result_queue, mock.Mock()) with mock.patch.object(result_queue, "put") as mock_put: task_process.run() expected_missing = [A().task_id, f"{B().task_id} (foo-B)", f"{C().task_id} (foo-C1, foo-C2)"] mock_put.assert_called_once_with( ( task.task_id, FAILED, StringContaining(f"Unfulfilled dependencies at run time: {', '.join(expected_missing)}"), expected_missing, [], ) ) def 
test_cleanup_children_on_terminate(self): """ Subprocesses spawned by tasks should be terminated on terminate """ class HangingSubprocessTask(luigi.Task): def run(self): python = sys.executable check_call([python, "-c", "while True: pass"]) task = HangingSubprocessTask() queue = mock.Mock() worker_id = 1 task_process = TaskProcess(task, worker_id, queue, mock.Mock()) task_process.start() parent = Process(task_process.pid) while not parent.children(): # wait for child process to startup sleep(0.01) [child] = parent.children() task_process.terminate() child.wait(timeout=1.0) # wait for terminate to complete self.assertFalse(parent.is_running()) self.assertFalse(child.is_running()) def test_disable_worker_timeout(self): """ When a task sets worker_timeout explicitly to 0, it should disable the timeout, even if it is configured globally. """ class Task(luigi.Task): worker_timeout = 0 task_process = TaskProcess( task=Task(), worker_id=1, result_queue=mock.Mock(), status_reporter=mock.Mock(), worker_timeout=10, ) self.assertEqual(task_process.worker_timeout, 0) ================================================ FILE: test/worker_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import email.parser import functools import logging import os import shutil import signal import tempfile import threading import time import mock import psutil from helpers import LuigiTestCase, skipOnTravisAndGithubActions, temporary_unloaded_module, unittest, with_config import luigi.notifications import luigi.task_register import luigi.worker from luigi import Event, ExternalTask, RemoteScheduler, Task from luigi.cmdline import luigi_run from luigi.mock import MockFileSystem, MockTarget from luigi.rpc import RPCError from luigi.scheduler import Scheduler from luigi.worker import Worker luigi.notifications.DEBUG = True class DummyTask(Task): def __init__(self, *args, **kwargs): super(DummyTask, self).__init__(*args, **kwargs) self.has_run = False def complete(self): return self.has_run def run(self): logging.debug("%s - setting has_run", self) self.has_run = True class DynamicDummyTask(Task): p = luigi.Parameter() sleep = luigi.FloatParameter(default=0.5, significant=False) def output(self): return luigi.LocalTarget(self.p) def run(self): with self.output().open("w") as f: f.write("Done!") time.sleep(self.sleep) # so we can benchmark & see if parallelization works class DynamicDummyTaskWithNamespace(DynamicDummyTask): task_namespace = "banana" class DynamicRequires(Task): p = luigi.Parameter() use_banana_task = luigi.BoolParameter(default=False) def output(self): return luigi.LocalTarget(os.path.join(self.p, "parent")) def run(self): if self.use_banana_task: task_cls = DynamicDummyTaskWithNamespace else: task_cls = DynamicDummyTask dummy_targets = yield [task_cls(os.path.join(self.p, str(i))) for i in range(5)] dummy_targets += yield [task_cls(os.path.join(self.p, str(i))) for i in range(5, 7)] with self.output().open("w") as f: for i, d in enumerate(dummy_targets): for line in d.open("r"): print("%d: %s" % (i, line.strip()), file=f) class DynamicRequiresWrapped(Task): p = luigi.Parameter() def output(self): return luigi.LocalTarget(os.path.join(self.p, 
"parent")) def run(self): reqs = [DynamicDummyTask(p=os.path.join(self.p, "%s.txt" % i), sleep=0.0) for i in range(10)] # yield again as DynamicRequires yield luigi.DynamicRequirements(reqs) # and again with a custom complete function that does base name comparisons def custom_complete(complete_fn): if not complete_fn(reqs[0]): return False paths = [task.output().path for task in reqs] basenames = os.listdir(os.path.dirname(paths[0])) self._custom_complete_called = True self._custom_complete_result = all(os.path.basename(path) in basenames for path in paths) return self._custom_complete_result yield luigi.DynamicRequirements(reqs, custom_complete) with self.output().open("w") as f: f.write("Done!") class DynamicRequiresOtherModule(Task): p = luigi.Parameter() def output(self): return luigi.LocalTarget(os.path.join(self.p, "baz")) def run(self): import other_module other_target_foo = yield other_module.OtherModuleTask(os.path.join(self.p, "foo")) # NOQA other_target_bar = yield other_module.OtherModuleTask(os.path.join(self.p, "bar")) # NOQA with self.output().open("w") as f: f.write("Done!") class DummyErrorTask(Task): retry_index = 0 def run(self): self.retry_index += 1 raise Exception("Retry index is %s for %s" % (self.retry_index, self.task_family)) class WorkerTest(LuigiTestCase): def run(self, result=None): self.sch = Scheduler(retry_delay=100, remove_delay=1000, worker_disconnect_delay=10, stable_done_cooldown_secs=0) self.time = time.time with Worker(scheduler=self.sch, worker_id="X") as w, Worker(scheduler=self.sch, worker_id="Y") as w2: self.w = w self.w2 = w2 super(WorkerTest, self).run(result) if time.time != self.time: time.time = self.time def setTime(self, t): time.time = lambda: t def test_dep(self): class A(Task): def run(self): self.has_run = True def complete(self): return self.has_run a = A() class B(Task): def requires(self): return a def run(self): self.has_run = True def complete(self): return self.has_run b = B() a.has_run = False b.has_run = 
False self.assertTrue(self.w.add(b)) self.assertTrue(self.w.run()) self.assertTrue(a.has_run) self.assertTrue(b.has_run) def test_external_dep(self): class A(ExternalTask): def complete(self): return False a = A() class B(Task): def requires(self): return a def run(self): self.has_run = True def complete(self): return self.has_run b = B() a.has_run = False b.has_run = False self.assertTrue(self.w.add(b)) self.assertTrue(self.w.run()) self.assertFalse(a.has_run) self.assertFalse(b.has_run) def test_externalized_dep(self): class A(Task): has_run = False def run(self): self.has_run = True def complete(self): return self.has_run a = A() class B(A): def requires(self): return luigi.task.externalize(a) b = B() self.assertTrue(self.w.add(b)) self.assertTrue(self.w.run()) self.assertFalse(a.has_run) self.assertFalse(b.has_run) def test_legacy_externalized_dep(self): class A(Task): has_run = False def run(self): self.has_run = True def complete(self): return self.has_run a = A() a.run = NotImplemented class B(A): def requires(self): return a b = B() self.assertTrue(self.w.add(b)) self.assertTrue(self.w.run()) self.assertFalse(a.has_run) self.assertFalse(b.has_run) def test_type_error_in_tracking_run_deprecated(self): class A(Task): num_runs = 0 def complete(self): return False def run(self, tracking_url_callback=None): self.num_runs += 1 raise TypeError("bad type") a = A() self.assertTrue(self.w.add(a)) self.assertFalse(self.w.run()) # Should only run and fail once, not retry because of the type error self.assertEqual(1, a.num_runs) def test_tracking_url(self): tracking_url = "http://test_url.com/" class A(Task): has_run = False def complete(self): return self.has_run def run(self): self.set_tracking_url(tracking_url) self.has_run = True a = A() self.assertTrue(self.w.add(a)) self.assertTrue(self.w.run()) tasks = self.sch.task_list("DONE", "") self.assertEqual(1, len(tasks)) self.assertEqual(tracking_url, tasks[a.task_id]["tracking_url"]) def test_fail(self): class 
CustomException(BaseException): def __init__(self, msg): self.msg = msg class A(Task): def run(self): self.has_run = True raise CustomException("bad things") def complete(self): return self.has_run a = A() class B(Task): def requires(self): return a def run(self): self.has_run = True def complete(self): return self.has_run b = B() a.has_run = False b.has_run = False self.assertTrue(self.w.add(b)) self.assertFalse(self.w.run()) self.assertTrue(a.has_run) self.assertFalse(b.has_run) def test_unknown_dep(self): # see related test_remove_dep test (grep for it) class A(ExternalTask): def complete(self): return False class C(Task): def complete(self): return True def get_b(dep): class B(Task): def requires(self): return dep def run(self): self.has_run = True def complete(self): return False b = B() b.has_run = False return b b_a = get_b(A()) b_c = get_b(C()) self.assertTrue(self.w.add(b_a)) # So now another worker goes in and schedules C -> B # This should remove the dep A -> B but will screw up the first worker self.assertTrue(self.w2.add(b_c)) self.assertFalse(self.w.run()) # should not run anything - the worker should detect that A is broken self.assertFalse(b_a.has_run) # not sure what should happen?? 
# self.w2.run() # should run B since C is fulfilled # self.assertTrue(b_c.has_run) def test_unfulfilled_dep(self): class A(Task): def complete(self): return self.done def run(self): self.done = True def get_b(a): class B(A): def requires(self): return a b = B() b.done = False a.done = True return b a = A() b = get_b(a) self.assertTrue(self.w.add(b)) a.done = False self.w.run() self.assertTrue(a.complete()) self.assertTrue(b.complete()) def test_check_unfulfilled_deps_config(self): class A(Task): i = luigi.IntParameter() def __init__(self, *args, **kwargs): super(A, self).__init__(*args, **kwargs) self.complete_count = 0 self.has_run = False def complete(self): self.complete_count += 1 return self.has_run def run(self): self.has_run = True class B(A): def requires(self): return A(i=self.i) # test the enabled features with Worker(scheduler=self.sch, worker_id="1") as w: w._config.check_unfulfilled_deps = True a1 = A(i=1) b1 = B(i=1) self.assertTrue(w.add(b1)) self.assertEqual(a1.complete_count, 1) self.assertEqual(b1.complete_count, 1) w.run() self.assertTrue(a1.complete()) self.assertTrue(b1.complete()) self.assertEqual(a1.complete_count, 3) self.assertEqual(b1.complete_count, 2) # test the disabled features with Worker(scheduler=self.sch, worker_id="2") as w: w._config.check_unfulfilled_deps = False a2 = A(i=2) b2 = B(i=2) self.assertTrue(w.add(b2)) self.assertEqual(a2.complete_count, 1) self.assertEqual(b2.complete_count, 1) w.run() self.assertTrue(a2.complete()) self.assertTrue(b2.complete()) self.assertEqual(a2.complete_count, 2) self.assertEqual(b2.complete_count, 2) def test_cache_task_completion_config(self): class A(Task): i = luigi.IntParameter() def __init__(self, *args, **kwargs): super(A, self).__init__(*args, **kwargs) self.complete_count = 0 self.has_run = False def complete(self): self.complete_count += 1 return self.has_run def run(self): self.has_run = True class B(A): def run(self): yield A(i=self.i + 0) yield A(i=self.i + 1) yield A(i=self.i + 2) 
self.has_run = True # test with enabled cache_task_completion with Worker(scheduler=self.sch, worker_id="2", cache_task_completion=True) as w: b0 = B(i=0) a0 = A(i=0) a1 = A(i=1) a2 = A(i=2) self.assertTrue(w.add(b0)) # a's are required dynamically, so their counts must be 0 self.assertEqual(b0.complete_count, 1) self.assertEqual(a0.complete_count, 0) self.assertEqual(a1.complete_count, 0) self.assertEqual(a2.complete_count, 0) w.run() # the complete methods of a's yielded first in b's run method were called equally often self.assertEqual(b0.complete_count, 1) self.assertEqual(a0.complete_count, 2) self.assertEqual(a1.complete_count, 2) self.assertEqual(a2.complete_count, 2) # test with disabled cache_task_completion with Worker(scheduler=self.sch, worker_id="2", cache_task_completion=False) as w: b10 = B(i=10) a10 = A(i=10) a11 = A(i=11) a12 = A(i=12) self.assertTrue(w.add(b10)) # a's are required dynamically, so their counts must be 0 self.assertEqual(b10.complete_count, 1) self.assertEqual(a10.complete_count, 0) self.assertEqual(a11.complete_count, 0) self.assertEqual(a12.complete_count, 0) w.run() # the complete methods of a's yielded first in b's run method were called more often self.assertEqual(b10.complete_count, 1) self.assertEqual(a10.complete_count, 5) self.assertEqual(a11.complete_count, 4) self.assertEqual(a12.complete_count, 3) # test with enabled check_complete_on_run with Worker(scheduler=self.sch, worker_id="2", check_complete_on_run=True) as w: b20 = B(i=20) a20 = A(i=20) a21 = A(i=21) a22 = A(i=22) self.assertTrue(w.add(b20)) # a's are required dynamically, so their counts must be 0 self.assertEqual(b20.complete_count, 1) self.assertEqual(a20.complete_count, 0) self.assertEqual(a21.complete_count, 0) self.assertEqual(a22.complete_count, 0) w.run() # the complete methods of a's yielded first in b's run method were called more often self.assertEqual(b20.complete_count, 2) self.assertEqual(a20.complete_count, 6) self.assertEqual(a21.complete_count, 
5) self.assertEqual(a22.complete_count, 4) def test_gets_missed_work(self): class A(Task): done = False def complete(self): return self.done def run(self): self.done = True a = A() self.assertTrue(self.w.add(a)) # simulate a missed get_work response self.assertEqual(a.task_id, self.sch.get_work(worker="X")["task_id"]) self.assertTrue(self.w.run()) self.assertTrue(a.complete()) def test_avoid_infinite_reschedule(self): class A(Task): def complete(self): return False class B(Task): def complete(self): return False def requires(self): return A() self.assertTrue(self.w.add(B())) self.assertFalse(self.w.run()) def test_fails_registering_signal(self): with mock.patch("luigi.worker.signal", spec=["signal"]): # mock will raise an attribute error getting signal.SIGUSR1 Worker() def test_allow_reschedule_with_many_missing_deps(self): class A(Task): """Task that must run twice to succeed""" i = luigi.IntParameter() runs = 0 def complete(self): return self.runs >= 2 def run(self): self.runs += 1 class B(Task): done = False def requires(self): return map(A, range(20)) def complete(self): return self.done def run(self): self.done = True b = B() w = Worker(scheduler=self.sch, worker_id="X", max_reschedules=1) self.assertTrue(w.add(b)) self.assertFalse(w.run()) # For b to be done, we must have rescheduled its dependencies to run them twice self.assertTrue(b.complete()) self.assertTrue(all(a.complete() for a in b.deps())) def test_interleaved_workers(self): class A(DummyTask): pass a = A() class B(DummyTask): def requires(self): return a ExternalB = luigi.task.externalize(B) b = B() eb = ExternalB() self.assertEqual(str(eb), "B()") sch = Scheduler(retry_delay=100, remove_delay=1000, worker_disconnect_delay=10) with Worker(scheduler=sch, worker_id="X") as w, Worker(scheduler=sch, worker_id="Y") as w2: self.assertTrue(w.add(b)) self.assertTrue(w2.add(eb)) logging.debug("RUNNING BROKEN WORKER") self.assertTrue(w2.run()) self.assertFalse(a.complete()) self.assertFalse(b.complete()) 
logging.debug("RUNNING FUNCTIONAL WORKER") self.assertTrue(w.run()) self.assertTrue(a.complete()) self.assertTrue(b.complete()) def test_interleaved_workers2(self): # two tasks without dependencies, one external, one not class B(DummyTask): pass ExternalB = luigi.task.externalize(B) b = B() eb = ExternalB() self.assertEqual(str(eb), "B()") sch = Scheduler(retry_delay=100, remove_delay=1000, worker_disconnect_delay=10) with Worker(scheduler=sch, worker_id="X") as w, Worker(scheduler=sch, worker_id="Y") as w2: self.assertTrue(w2.add(eb)) self.assertTrue(w.add(b)) self.assertTrue(w2.run()) self.assertFalse(b.complete()) self.assertTrue(w.run()) self.assertTrue(b.complete()) def test_interleaved_workers3(self): class A(DummyTask): def run(self): logging.debug("running A") time.sleep(0.1) super(A, self).run() a = A() class B(DummyTask): def requires(self): return a def run(self): logging.debug("running B") super(B, self).run() b = B() sch = Scheduler(retry_delay=100, remove_delay=1000, worker_disconnect_delay=10) with Worker(scheduler=sch, worker_id="X", keep_alive=True, count_uniques=True) as w: with Worker(scheduler=sch, worker_id="Y", keep_alive=True, count_uniques=True, wait_interval=0.1, wait_jitter=0.05) as w2: self.assertTrue(w.add(a)) self.assertTrue(w2.add(b)) threading.Thread(target=w.run).start() self.assertTrue(w2.run()) self.assertTrue(a.complete()) self.assertTrue(b.complete()) def test_die_for_non_unique_pending(self): class A(DummyTask): def run(self): logging.debug("running A") time.sleep(0.1) super(A, self).run() a = A() class B(DummyTask): def requires(self): return a def run(self): logging.debug("running B") super(B, self).run() b = B() sch = Scheduler(retry_delay=100, remove_delay=1000, worker_disconnect_delay=10) with Worker(scheduler=sch, worker_id="X", keep_alive=True, count_uniques=True) as w: with Worker(scheduler=sch, worker_id="Y", keep_alive=True, count_uniques=True, wait_interval=0.1, wait_jitter=0.05) as w2: self.assertTrue(w.add(b)) 
self.assertTrue(w2.add(b)) self.assertEqual(w._get_work()[0], a.task_id) self.assertTrue(w2.run()) self.assertFalse(a.complete()) self.assertFalse(b.complete()) def test_complete_exception(self): "Tests that a task is still scheduled if its sister task crashes in the complete() method" class A(DummyTask): def complete(self): raise Exception("doh") a = A() class C(DummyTask): pass c = C() class B(DummyTask): def requires(self): return a, c b = B() sch = Scheduler(retry_delay=100, remove_delay=1000, worker_disconnect_delay=10) with Worker(scheduler=sch, worker_id="foo") as w: self.assertFalse(w.add(b)) self.assertTrue(w.run()) self.assertFalse(b.has_run) self.assertTrue(c.has_run) self.assertFalse(a.has_run) def test_requires_exception(self): class A(DummyTask): def requires(self): raise Exception("doh") a = A() class D(DummyTask): pass d = D() class C(DummyTask): def requires(self): return d c = C() class B(DummyTask): def requires(self): return c, a b = B() sch = Scheduler(retry_delay=100, remove_delay=1000, worker_disconnect_delay=10) with Worker(scheduler=sch, worker_id="foo") as w: self.assertFalse(w.add(b)) self.assertTrue(w.run()) self.assertFalse(b.has_run) self.assertTrue(c.has_run) self.assertTrue(d.has_run) self.assertFalse(a.has_run) def test_run_csv_batch_job(self): completed = set() class CsvBatchJob(luigi.Task): values = luigi.parameter.Parameter(batch_method=",".join) has_run = False def run(self): completed.update(self.values.split(",")) self.has_run = True def complete(self): return all(value in completed for value in self.values.split(",")) tasks = [CsvBatchJob(str(i)) for i in range(10)] for task in tasks: self.assertTrue(self.w.add(task)) self.assertTrue(self.w.run()) for task in tasks: self.assertTrue(task.complete()) self.assertFalse(task.has_run) def test_run_max_batch_job(self): completed = set() class MaxBatchJob(luigi.Task): value = luigi.IntParameter(batch_method=max) has_run = False def run(self): completed.add(self.value) self.has_run = 
True def complete(self): return any(self.value <= ran for ran in completed) tasks = [MaxBatchJob(i) for i in range(10)] for task in tasks: self.assertTrue(self.w.add(task)) self.assertTrue(self.w.run()) for task in tasks: self.assertTrue(task.complete()) # only task number 9 should run self.assertFalse(task.has_run and task.value < 9) def test_run_batch_job_unbatched(self): completed = set() class MaxNonBatchJob(luigi.Task): value = luigi.IntParameter(batch_method=max) has_run = False batchable = False def run(self): completed.add(self.value) self.has_run = True def complete(self): return self.value in completed tasks = [MaxNonBatchJob((i,)) for i in range(10)] for task in tasks: self.assertTrue(self.w.add(task)) self.assertTrue(self.w.run()) for task in tasks: self.assertTrue(task.complete()) self.assertTrue(task.has_run) def test_run_batch_job_limit_batch_size(self): completed = set() runs = [] class CsvLimitedBatchJob(luigi.Task): value = luigi.parameter.Parameter(batch_method=",".join) has_run = False max_batch_size = 4 def run(self): completed.update(self.value.split(",")) runs.append(self) def complete(self): return all(value in completed for value in self.value.split(",")) tasks = [CsvLimitedBatchJob(str(i)) for i in range(11)] for task in tasks: self.assertTrue(self.w.add(task)) self.assertTrue(self.w.run()) for task in tasks: self.assertTrue(task.complete()) self.assertEqual(3, len(runs)) def test_fail_max_batch_job(self): class MaxBatchFailJob(luigi.Task): value = luigi.IntParameter(batch_method=max) has_run = False def run(self): self.has_run = True assert False def complete(self): return False tasks = [MaxBatchFailJob(i) for i in range(10)] for task in tasks: self.assertTrue(self.w.add(task)) self.assertFalse(self.w.run()) for task in tasks: # only task number 9 should run self.assertFalse(task.has_run and task.value < 9) self.assertEqual({task.task_id for task in tasks}, set(self.sch.task_list("FAILED", ""))) def 
test_gracefully_handle_batch_method_failure(self): class BadBatchMethodTask(DummyTask): priority = 10 batch_int_param = luigi.IntParameter(batch_method=int.__add__) # should be sum bad_tasks = [BadBatchMethodTask(i) for i in range(5)] good_tasks = [DummyTask()] all_tasks = good_tasks + bad_tasks self.assertFalse(any(task.complete() for task in all_tasks)) worker = Worker(scheduler=Scheduler(retry_count=1), keep_alive=True) for task in all_tasks: self.assertTrue(worker.add(task)) self.assertFalse(worker.run()) self.assertFalse(any(task.complete() for task in bad_tasks)) # we only get to run the good task if the bad task failures were handled gracefully self.assertTrue(all(task.complete() for task in good_tasks)) def test_post_error_message_for_failed_batch_methods(self): class BadBatchMethodTask(DummyTask): batch_int_param = luigi.IntParameter(batch_method=int.__add__) # should be sum tasks = [BadBatchMethodTask(1), BadBatchMethodTask(2)] for task in tasks: self.assertTrue(self.w.add(task)) self.assertFalse(self.w.run()) failed_ids = set(self.sch.task_list("FAILED", "")) self.assertEqual({task.task_id for task in tasks}, failed_ids) self.assertTrue(all(self.sch.fetch_error(task_id)["error"] for task_id in failed_ids)) class WorkerKeepAliveTests(LuigiTestCase): def setUp(self): self.sch = Scheduler() super(WorkerKeepAliveTests, self).setUp() def _worker_keep_alive_test(self, first_should_live, second_should_live, task_status=None, **worker_args): worker_args.update( { "scheduler": self.sch, "worker_processes": 0, "wait_interval": 0.01, "wait_jitter": 0.0, } ) w1 = Worker(worker_id="w1", **worker_args) w2 = Worker(worker_id="w2", **worker_args) with w1 as worker1, w2 as worker2: worker1.add(DummyTask()) t1 = threading.Thread(target=worker1.run) t1.start() worker2.add(DummyTask()) t2 = threading.Thread(target=worker2.run) t2.start() if task_status: self.sch.add_task(worker="DummyWorker", task_id=DummyTask().task_id, status=task_status) # allow workers to run their get 
work loops a few times time.sleep(0.1) try: self.assertEqual(first_should_live, t1.is_alive()) self.assertEqual(second_should_live, t2.is_alive()) finally: # mark the task done so the worker threads will die self.sch.add_task(worker="DummyWorker", task_id=DummyTask().task_id, status="DONE") t1.join() t2.join() def test_no_keep_alive(self): self._worker_keep_alive_test( first_should_live=False, second_should_live=False, ) def test_keep_alive(self): self._worker_keep_alive_test( first_should_live=True, second_should_live=True, keep_alive=True, ) def test_keep_alive_count_uniques(self): self._worker_keep_alive_test( first_should_live=False, second_should_live=False, keep_alive=True, count_uniques=True, ) def test_keep_alive_count_last_scheduled(self): self._worker_keep_alive_test( first_should_live=False, second_should_live=True, keep_alive=True, count_last_scheduled=True, ) def test_keep_alive_through_failure(self): self._worker_keep_alive_test( first_should_live=True, second_should_live=True, keep_alive=True, task_status="FAILED", ) def test_do_not_keep_alive_through_disable(self): self._worker_keep_alive_test( first_should_live=False, second_should_live=False, keep_alive=True, task_status="DISABLED", ) class WorkerInterruptedTest(unittest.TestCase): def setUp(self): self.sch = Scheduler(retry_delay=100, remove_delay=1000, worker_disconnect_delay=10) requiring_sigusr = unittest.skipUnless(hasattr(signal, "SIGUSR1"), "signal.SIGUSR1 not found on this system") def _test_stop_getting_new_work(self, worker): d = DummyTask() with worker: worker.add(d) # For assistant its ok that other tasks add it self.assertFalse(d.complete()) worker.handle_interrupt(signal.SIGUSR1, None) worker.run() self.assertFalse(d.complete()) @requiring_sigusr def test_stop_getting_new_work(self): self._test_stop_getting_new_work(Worker(scheduler=self.sch)) @requiring_sigusr def test_stop_getting_new_work_assistant(self): self._test_stop_getting_new_work(Worker(scheduler=self.sch, 
keep_alive=False, assistant=True)) @requiring_sigusr def test_stop_getting_new_work_assistant_keep_alive(self): self._test_stop_getting_new_work(Worker(scheduler=self.sch, keep_alive=True, assistant=True)) def test_existence_of_disabling_option(self): # any code equivalent of `os.kill(os.getpid(), signal.SIGUSR1)` # seem to give some sort of a "InvocationError" Worker(no_install_shutdown_handler=True) @with_config({"worker": {"no_install_shutdown_handler": "True"}}) def test_can_run_luigi_in_thread(self): class A(DummyTask): pass task = A() # Note that ``signal.signal(signal.SIGUSR1, fn)`` can only be called in the main thread. # So if we do not disable the shutdown handler, this would fail. t = threading.Thread(target=lambda: luigi.build([task], local_scheduler=True)) t.start() t.join() self.assertTrue(task.complete()) class WorkerDisabledTest(LuigiTestCase): def make_sch(self): return Scheduler(retry_delay=100, remove_delay=1000, worker_disconnect_delay=10) def _test_stop_getting_new_work_build(self, sch, worker): """ I got motivated to create this test case when I saw that the execution_summary crashed after my first attempted solution. 
""" class KillWorkerTask(luigi.Task): did_actually_run = False def run(self): sch.disable_worker("my_worker_id") KillWorkerTask.did_actually_run = True class Factory: def create_local_scheduler(self, *args, **kwargs): return sch def create_worker(self, *args, **kwargs): return worker luigi.build([KillWorkerTask()], worker_scheduler_factory=Factory(), local_scheduler=True) self.assertTrue(KillWorkerTask.did_actually_run) def _test_stop_getting_new_work_manual(self, sch, worker): d = DummyTask() with worker: worker.add(d) # For assistant its ok that other tasks add it self.assertFalse(d.complete()) sch.disable_worker("my_worker_id") worker.run() # Note: Test could fail by hanging on this line self.assertFalse(d.complete()) def _test_stop_getting_new_work(self, **worker_kwargs): worker_kwargs["worker_id"] = "my_worker_id" sch = self.make_sch() worker_kwargs["scheduler"] = sch self._test_stop_getting_new_work_manual(sch, Worker(**worker_kwargs)) sch = self.make_sch() worker_kwargs["scheduler"] = sch self._test_stop_getting_new_work_build(sch, Worker(**worker_kwargs)) def test_stop_getting_new_work_keep_alive(self): self._test_stop_getting_new_work(keep_alive=True, assistant=False) def test_stop_getting_new_work_assistant(self): self._test_stop_getting_new_work(keep_alive=False, assistant=True) def test_stop_getting_new_work_assistant_keep_alive(self): self._test_stop_getting_new_work(keep_alive=True, assistant=True) class DynamicDependenciesTest(LuigiTestCase): n_workers = 1 timeout = float("inf") def setUp(self): self.p = tempfile.mkdtemp() def tearDown(self): shutil.rmtree(self.p) def test_dynamic_dependencies(self, use_banana_task=False): t0 = time.time() t = DynamicRequires(p=self.p, use_banana_task=use_banana_task) luigi.build([t], local_scheduler=True, workers=self.n_workers) self.assertTrue(t.complete()) # loop through output and verify with t.output().open("r") as f: for i in range(7): self.assertEqual(f.readline().strip(), "%d: Done!" 
% i) self.assertTrue(time.time() - t0 < self.timeout) def test_dynamic_dependencies_with_namespace(self): self.test_dynamic_dependencies(use_banana_task=True) def test_dynamic_dependencies_other_module(self): t = DynamicRequiresOtherModule(p=self.p) luigi.build([t], local_scheduler=True, workers=self.n_workers) self.assertTrue(t.complete()) def test_wrapped_dynamic_requirements(self): t = DynamicRequiresWrapped(p=self.p) luigi.build([t], local_scheduler=True, workers=1) self.assertTrue(t.complete()) self.assertTrue(getattr(t, "_custom_complete_called", False)) self.assertTrue(getattr(t, "_custom_complete_result", False)) class DynamicDependenciesWithMultipleWorkersTest(DynamicDependenciesTest): n_workers = 100 timeout = 10.0 # We run 7 tasks that take 0.5s each so it should take less than 3.5s class WorkerPingThreadTests(unittest.TestCase): def test_ping_retry(self): """Worker ping fails once. Ping continues to try to connect to scheduler Kind of ugly since it uses actual timing with sleep to test the thread """ sch = Scheduler( retry_delay=100, remove_delay=1000, worker_disconnect_delay=10, ) self._total_pings = 0 # class var so it can be accessed from fail_ping def fail_ping(worker): # this will be called from within keep-alive thread... self._total_pings += 1 raise Exception("Some random exception") sch.ping = fail_ping with Worker( scheduler=sch, worker_id="foo", ping_interval=0.01, # very short between pings to make test fast ): # let the keep-alive thread run for a bit... 
time.sleep(0.1) # yes, this is ugly but it's exactly what we need to test self.assertTrue(self._total_pings > 1, msg="Didn't retry pings (%d pings performed)" % (self._total_pings,)) def test_ping_thread_shutdown(self): with Worker(ping_interval=0.01) as w: self.assertTrue(w._keep_alive_thread.is_alive()) self.assertFalse(w._keep_alive_thread.is_alive()) def email_patch(test_func, email_config=None): EMAIL_CONFIG = {"email": {"receiver": "not-a-real-email-address-for-test-only", "force_send": "true"}} if email_config is not None: EMAIL_CONFIG.update(email_config) emails = [] def mock_send_email(sender, recipients, msg): emails.append(msg) @with_config(EMAIL_CONFIG) @functools.wraps(test_func) @mock.patch("smtplib.SMTP") def run_test(self, smtp): smtp().sendmail.side_effect = mock_send_email test_func(self, emails) return run_test def custom_email_patch(config): return functools.partial(email_patch, email_config=config) class WorkerEmailTest(LuigiTestCase): def run(self, result=None): super(WorkerEmailTest, self).setUp() sch = Scheduler(retry_delay=100, remove_delay=1000, worker_disconnect_delay=10) with Worker(scheduler=sch, worker_id="foo") as self.worker: super(WorkerEmailTest, self).run(result) @email_patch def test_connection_error(self, emails): sch = RemoteScheduler("http://tld.invalid:1337", connect_timeout=1) sch._rpc_retry_wait = 1 # shorten wait time to speed up tests class A(DummyTask): pass a = A() self.assertEqual(emails, []) with Worker(scheduler=sch) as worker: try: worker.add(a) except RPCError as e: self.assertTrue(str(e).find("Errors (3 attempts)") != -1) self.assertNotEqual(emails, []) self.assertTrue(emails[0].find("Luigi: Framework error while scheduling %s" % (a,)) != -1) else: self.fail() @email_patch def test_complete_error(self, emails): class A(DummyTask): def complete(self): raise Exception("b0rk") a = A() self.assertEqual(emails, []) self.worker.add(a) self.assertTrue(emails[0].find("Luigi: %s failed scheduling" % (a,)) != -1) 
self.worker.run() self.assertTrue(emails[0].find("Luigi: %s failed scheduling" % (a,)) != -1) self.assertFalse(a.has_run) @with_config({"batch_email": {"email_interval": "0"}, "worker": {"send_failure_email": "False"}}) @email_patch def test_complete_error_email_batch(self, emails): class A(DummyTask): def complete(self): raise Exception("b0rk") scheduler = Scheduler(batch_emails=True) worker = Worker(scheduler) a = A() self.assertEqual(emails, []) worker.add(a) self.assertEqual(emails, []) worker.run() self.assertEqual(emails, []) self.assertFalse(a.has_run) scheduler.prune() self.assertTrue("1 scheduling failure" in emails[0]) @with_config({"batch_email": {"email_interval": "0"}, "worker": {"send_failure_email": "False"}}) @email_patch def test_complete_error_email_batch_to_owner(self, emails): class A(DummyTask): owner_email = "a_owner@test.com" def complete(self): raise Exception("b0rk") scheduler = Scheduler(batch_emails=True) worker = Worker(scheduler) a = A() self.assertEqual(emails, []) worker.add(a) self.assertEqual(emails, []) worker.run() self.assertEqual(emails, []) self.assertFalse(a.has_run) scheduler.prune() self.assertTrue(any("1 scheduling failure" in email and "a_owner@test.com" in email for email in emails)) @email_patch def test_announce_scheduling_failure_unexpected_error(self, emails): class A(DummyTask): owner_email = "a_owner@test.com" def complete(self): pass scheduler = Scheduler(batch_emails=True) worker = Worker(scheduler) a = A() with mock.patch.object(worker._scheduler, "announce_scheduling_failure", side_effect=Exception("Unexpected")), self.assertRaises(Exception): worker.add(a) self.assertTrue(len(emails) == 2) # One for `complete` error, one for exception in announcing. 
self.assertTrue("Luigi: Framework error while scheduling" in emails[1]) self.assertTrue("a_owner@test.com" in emails[1]) @email_patch def test_requires_error(self, emails): class A(DummyTask): def requires(self): raise Exception("b0rk") a = A() self.assertEqual(emails, []) self.worker.add(a) self.assertTrue(emails[0].find("Luigi: %s failed scheduling" % (a,)) != -1) self.worker.run() self.assertFalse(a.has_run) @with_config({"batch_email": {"email_interval": "0"}, "worker": {"send_failure_email": "False"}}) @email_patch def test_requires_error_email_batch(self, emails): class A(DummyTask): def requires(self): raise Exception("b0rk") scheduler = Scheduler(batch_emails=True) worker = Worker(scheduler) a = A() self.assertEqual(emails, []) worker.add(a) self.assertEqual(emails, []) worker.run() self.assertFalse(a.has_run) scheduler.prune() self.assertTrue("1 scheduling failure" in emails[0]) @email_patch def test_complete_return_value(self, emails): class A(DummyTask): def complete(self): pass # no return value should be an error a = A() self.assertEqual(emails, []) self.worker.add(a) self.assertTrue(emails[0].find("Luigi: %s failed scheduling" % (a,)) != -1) self.worker.run() self.assertTrue(emails[0].find("Luigi: %s failed scheduling" % (a,)) != -1) self.assertFalse(a.has_run) @with_config({"batch_email": {"email_interval": "0"}, "worker": {"send_failure_email": "False"}}) @email_patch def test_complete_return_value_email_batch(self, emails): class A(DummyTask): def complete(self): pass # no return value should be an error scheduler = Scheduler(batch_emails=True) worker = Worker(scheduler) a = A() self.assertEqual(emails, []) worker.add(a) self.assertEqual(emails, []) self.worker.run() self.assertEqual(emails, []) self.assertFalse(a.has_run) scheduler.prune() self.assertTrue("1 scheduling failure" in emails[0]) @email_patch def test_run_error(self, emails): class A(luigi.Task): def run(self): raise Exception("b0rk") a = A() luigi.build([a], workers=1, 
local_scheduler=True) self.assertEqual(1, len(emails)) self.assertTrue(emails[0].find("Luigi: %s FAILED" % (a,)) != -1) @email_patch def test_run_error_long_traceback(self, emails): class A(luigi.Task): def run(self): raise Exception("b0rk" * 10500) a = A() luigi.build([a], workers=1, local_scheduler=True) self.assertTrue(len(emails[0]) < 10000) self.assertTrue(emails[0].find("Traceback exceeds max length and has been truncated")) @with_config({"batch_email": {"email_interval": "0"}, "worker": {"send_failure_email": "False"}}) @email_patch def test_run_error_email_batch(self, emails): class A(luigi.Task): owner_email = ["a@test.com", "b@test.com"] def run(self): raise Exception("b0rk") scheduler = Scheduler(batch_emails=True) worker = Worker(scheduler) worker.add(A()) worker.run() scheduler.prune() self.assertEqual(3, len(emails)) self.assertTrue(any("a@test.com" in email for email in emails)) self.assertTrue(any("b@test.com" in email for email in emails)) @with_config({"batch_email": {"email_interval": "0"}, "worker": {"send_failure_email": "False"}}) @email_patch def test_run_error_batch_email_string(self, emails): class A(luigi.Task): owner_email = "a@test.com" def run(self): raise Exception("b0rk") scheduler = Scheduler(batch_emails=True) worker = Worker(scheduler) worker.add(A()) worker.run() scheduler.prune() self.assertEqual(2, len(emails)) self.assertTrue(any("a@test.com" in email for email in emails)) @with_config({"worker": {"send_failure_email": "False"}}) @email_patch def test_run_error_no_email(self, emails): class A(luigi.Task): def run(self): raise Exception("b0rk") luigi.build([A()], workers=1, local_scheduler=True) self.assertFalse(emails) @staticmethod def read_email(email_msg): subject_obj, body_obj = email.parser.Parser().parsestr(email_msg).walk() return str(subject_obj["Subject"]), str(body_obj.get_payload(decode=True)) @email_patch def test_task_process_dies_with_email(self, emails): a = SendSignalTask(signal.SIGKILL) luigi.build([a], 
workers=2, local_scheduler=True) self.assertEqual(1, len(emails)) subject, body = self.read_email(emails[0]) self.assertIn("Luigi: {} FAILED".format(a), subject) self.assertIn("died unexpectedly with exit code -9", body) @with_config({"worker": {"send_failure_email": "False"}}) @email_patch def test_task_process_dies_no_email(self, emails): luigi.build([SendSignalTask(signal.SIGKILL)], workers=2, local_scheduler=True) self.assertEqual([], emails) @email_patch def test_task_times_out(self, emails): class A(luigi.Task): worker_timeout = 0.0001 def run(self): time.sleep(5) a = A() luigi.build([a], workers=2, local_scheduler=True) self.assertEqual(1, len(emails)) subject, body = self.read_email(emails[0]) self.assertIn("Luigi: %s FAILED" % (a,), subject) self.assertIn("timed out after 0.0001 seconds and was terminated.", body) @with_config({"worker": {"send_failure_email": "False"}}) @email_patch def test_task_times_out_no_email(self, emails): class A(luigi.Task): worker_timeout = 0.0001 def run(self): time.sleep(5) luigi.build([A()], workers=2, local_scheduler=True) self.assertEqual([], emails) @with_config(dict(worker=dict(retry_external_tasks="true"))) @email_patch def test_external_task_retries(self, emails): """ Test that we do not send error emails on the failures of external tasks """ class A(luigi.ExternalTask): pass a = A() luigi.build([a], workers=2, local_scheduler=True) self.assertEqual(emails, []) @email_patch def test_no_error(self, emails): class A(DummyTask): pass a = A() self.assertEqual(emails, []) self.worker.add(a) self.assertEqual(emails, []) self.worker.run() self.assertEqual(emails, []) self.assertTrue(a.complete()) @custom_email_patch({"email": {"receiver": "not-a-real-email-address-for-test-only", "format": "none"}}) def test_disable_emails(self, emails): class A(luigi.Task): def complete(self): raise Exception("b0rk") self.worker.add(A()) self.assertEqual(emails, []) class RaiseSystemExit(luigi.Task): def run(self): raise SystemExit("System 
exit!!") class SendSignalTask(luigi.Task): signal = luigi.IntParameter() def run(self): os.kill(os.getpid(), self.signal) class HangTheWorkerTask(luigi.Task): worker_timeout = luigi.IntParameter(default=None) def run(self): while True: pass def complete(self): return False class MultipleWorkersTest(LuigiTestCase): @unittest.skip("Always skip. There are many intermittent failures") def test_multiple_workers(self): # Test using multiple workers # Also test generating classes dynamically since this may reflect issues with # various platform and how multiprocessing is implemented. If it's using os.fork # under the hood it should be fine, but dynamic classses can't be pickled, so # other implementations of multiprocessing (using spawn etc) may fail class MyDynamicTask(luigi.Task): x = luigi.Parameter() def run(self): time.sleep(0.1) t0 = time.time() luigi.build([MyDynamicTask(i) for i in range(100)], workers=100, local_scheduler=True) self.assertTrue(time.time() < t0 + 5.0) # should ideally take exactly 0.1s, but definitely less than 10.0 def test_zero_workers(self): d = DummyTask() luigi.build([d], workers=0, local_scheduler=True) self.assertFalse(d.complete()) def test_system_exit(self): # This would hang indefinitely before this fix: # https://github.com/spotify/luigi/pull/439 luigi.build([RaiseSystemExit()], workers=2, local_scheduler=True) def test_term_worker(self): luigi.build([SendSignalTask(signal.SIGTERM)], workers=2, local_scheduler=True) def test_kill_worker(self): luigi.build([SendSignalTask(signal.SIGKILL)], workers=2, local_scheduler=True) def test_purge_multiple_workers(self): w = Worker(worker_processes=2, wait_interval=0.01) t1 = SendSignalTask(signal.SIGTERM) t2 = SendSignalTask(signal.SIGKILL) w.add(t1) w.add(t2) w._run_task(t1.task_id) w._run_task(t2.task_id) time.sleep(1.0) w._handle_next_task() w._handle_next_task() w._handle_next_task() def test_stop_worker_kills_subprocesses(self): with Worker(worker_processes=2) as w: hung_task = 
HangTheWorkerTask() w.add(hung_task) w._run_task(hung_task.task_id) pids = [p.pid for p in w._running_tasks.values()] self.assertEqual(1, len(pids)) pid = pids[0] def is_running(): return pid in {p.pid for p in psutil.Process().children()} self.assertTrue(is_running()) self.assertFalse(is_running()) @mock.patch("luigi.worker.time") def test_no_process_leak_from_repeatedly_running_same_task(self, worker_time): with Worker(worker_processes=2) as w: hung_task = HangTheWorkerTask() w.add(hung_task) w._run_task(hung_task.task_id) children = set(psutil.Process().children()) # repeatedly try to run the same task id for _ in range(10): worker_time.sleep.reset_mock() w._run_task(hung_task.task_id) # should sleep after each attempt worker_time.sleep.assert_called_once_with(mock.ANY) # only one process should be running self.assertEqual(children, set(psutil.Process().children())) def test_time_out_hung_worker(self): luigi.build([HangTheWorkerTask(0.1)], workers=2, local_scheduler=True) def test_time_out_hung_single_worker(self): luigi.build([HangTheWorkerTask(0.1)], workers=1, local_scheduler=True) @skipOnTravisAndGithubActions("https://travis-ci.org/spotify/luigi/jobs/72953986") @mock.patch("luigi.worker.time") def test_purge_hung_worker_default_timeout_time(self, mock_time): w = Worker(worker_processes=2, wait_interval=0.01, timeout=5) mock_time.time.return_value = 0 task = HangTheWorkerTask() w.add(task) w._run_task(task.task_id) mock_time.time.return_value = 5 w._handle_next_task() self.assertEqual(1, len(w._running_tasks)) mock_time.time.return_value = 6 w._handle_next_task() self.assertEqual(0, len(w._running_tasks)) @skipOnTravisAndGithubActions("https://travis-ci.org/spotify/luigi/jobs/76645264") @mock.patch("luigi.worker.time") def test_purge_hung_worker_override_timeout_time(self, mock_time): w = Worker(worker_processes=2, wait_interval=0.01, timeout=5) mock_time.time.return_value = 0 task = HangTheWorkerTask(worker_timeout=10) w.add(task) w._run_task(task.task_id) 
mock_time.time.return_value = 10 w._handle_next_task() self.assertEqual(1, len(w._running_tasks)) mock_time.time.return_value = 11 w._handle_next_task() self.assertEqual(0, len(w._running_tasks)) class Dummy2Task(Task): p = luigi.Parameter() def output(self): return MockTarget(self.p) def run(self): f = self.output().open("w") f.write("test") f.close() class AssistantTest(LuigiTestCase): def run(self, result=None): self.sch = Scheduler(retry_delay=100, remove_delay=1000, worker_disconnect_delay=10) self.assistant = Worker(scheduler=self.sch, worker_id="Y", assistant=True) with Worker(scheduler=self.sch, worker_id="X") as w: self.w = w super(AssistantTest, self).run(result) def test_get_work(self): d = Dummy2Task("123") self.w.add(d) self.assertFalse(d.complete()) self.assistant.run() self.assertTrue(d.complete()) def test_bad_job_type(self): class Dummy3Task(Dummy2Task): task_family = "UnknownTaskFamily" d = Dummy3Task("123") self.w.add(d) self.assertFalse(d.complete()) self.assertFalse(self.assistant.run()) self.assertFalse(d.complete()) self.assertEqual(list(self.sch.task_list("FAILED", "").keys()), [d.task_id]) def test_unimported_job_type(self): MODULE_CONTENTS = b""" import luigi class UnimportedTask(luigi.Task): def complete(self): return False """ reg = luigi.task_register.Register._get_reg() class UnimportedTask(luigi.Task): task_module = None # Set it here, so it's generally settable luigi.task_register.Register._set_reg(reg) task = UnimportedTask() # verify that it can't run the task without the module info necessary to import it self.w.add(task) self.assertFalse(self.assistant.run()) self.assertEqual(list(self.sch.task_list("FAILED", "").keys()), [task.task_id]) # check that it can import with the right module with temporary_unloaded_module(MODULE_CONTENTS) as task.task_module: self.w.add(task) self.assertTrue(self.assistant.run()) self.assertEqual(list(self.sch.task_list("DONE", "").keys()), [task.task_id]) def 
test_unimported_job_sends_failure_message(self): class NotInAssistantTask(luigi.Task): task_family = "Unknown" task_module = None task = NotInAssistantTask() self.w.add(task) self.assertFalse(self.assistant.run()) self.assertEqual(list(self.sch.task_list("FAILED", "").keys()), [task.task_id]) self.assertTrue(self.sch.fetch_error(task.task_id)["error"]) class ForkBombTask(luigi.Task): depth = luigi.IntParameter() breadth = luigi.IntParameter() p = luigi.Parameter(default=(0,)) # ehm for some weird reason [0] becomes a tuple...? def output(self): return MockTarget(".".join(map(str, self.p))) def run(self): with self.output().open("w") as f: f.write("Done!") def requires(self): if len(self.p) < self.depth: for i in range(self.breadth): yield ForkBombTask(self.depth, self.breadth, self.p + (i,)) class TaskLimitTest(unittest.TestCase): def tearDown(self): MockFileSystem().remove("") @with_config({"worker": {"task_limit": "6"}}) def test_task_limit_exceeded(self): w = Worker() t = ForkBombTask(3, 2) w.add(t) w.run() self.assertFalse(t.complete()) leaf_tasks = [ForkBombTask(3, 2, branch) for branch in [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1)]] self.assertEqual( 3, sum(t.complete() for t in leaf_tasks), "should have gracefully completed as much as possible even though the single last leaf didn't get scheduled", ) @with_config({"worker": {"task_limit": "7"}}) def test_task_limit_not_exceeded(self): w = Worker() t = ForkBombTask(3, 2) w.add(t) w.run() self.assertTrue(t.complete()) def test_no_task_limit(self): w = Worker() t = ForkBombTask(4, 2) w.add(t) w.run() self.assertTrue(t.complete()) class WorkerConfigurationTest(unittest.TestCase): def test_asserts_for_worker(self): """ Test that Worker() asserts that it's sanely configured """ Worker(wait_interval=1) # This shouldn't raise self.assertRaises(AssertionError, Worker, wait_interval=0) class WorkerWaitJitterTest(unittest.TestCase): @with_config({"worker": {"wait_jitter": "10.0"}}) @mock.patch("random.uniform") 
@mock.patch("time.sleep") def test_wait_jitter(self, mock_sleep, mock_random): """verify configured jitter amount""" mock_random.return_value = 1.0 w = Worker() x = w._sleeper() next(x) mock_random.assert_called_with(0, 10.0) mock_sleep.assert_called_with(2.0) mock_random.return_value = 2.0 next(x) mock_random.assert_called_with(0, 10.0) mock_sleep.assert_called_with(3.0) @mock.patch("random.uniform") @mock.patch("time.sleep") def test_wait_jitter_default(self, mock_sleep, mock_random): """verify default jitter is as expected""" mock_random.return_value = 1.0 w = Worker() x = w._sleeper() next(x) mock_random.assert_called_with(0, 5.0) mock_sleep.assert_called_with(2.0) mock_random.return_value = 3.3 next(x) mock_random.assert_called_with(0, 5.0) mock_sleep.assert_called_with(4.3) class KeyboardInterruptBehaviorTest(LuigiTestCase): def test_propagation_when_executing(self): """ Ensure that keyboard interrupts causes luigi to quit when you are executing tasks. TODO: Add a test that tests the multiprocessing (--worker >1) case """ class KeyboardInterruptTask(luigi.Task): def run(self): raise KeyboardInterrupt() cmd = "KeyboardInterruptTask --local-scheduler --no-lock".split(" ") self.assertRaises(KeyboardInterrupt, luigi_run, cmd) def test_propagation_when_scheduling(self): """ Test that KeyboardInterrupt causes luigi to quit while scheduling. 
""" class KeyboardInterruptTask(luigi.Task): def complete(self): raise KeyboardInterrupt() class ExternalKeyboardInterruptTask(luigi.ExternalTask): def complete(self): raise KeyboardInterrupt() self.assertRaises(KeyboardInterrupt, luigi_run, ["KeyboardInterruptTask", "--local-scheduler", "--no-lock"]) self.assertRaises(KeyboardInterrupt, luigi_run, ["ExternalKeyboardInterruptTask", "--local-scheduler", "--no-lock"]) class WorkerPurgeEventHandlerTest(unittest.TestCase): @mock.patch("luigi.worker.ContextManagedTaskProcess") def test_process_killed_handler(self, task_proc): result = [] @HangTheWorkerTask.event_handler(Event.PROCESS_FAILURE) def store_task(t, error_msg): self.assertTrue(error_msg) result.append(t) w = Worker() task = HangTheWorkerTask() task_process = mock.MagicMock(is_alive=lambda: False, exitcode=-14, task=task) task_proc.return_value = task_process w.add(task) w._run_task(task.task_id) w._handle_next_task() self.assertEqual(result, [task]) @mock.patch("luigi.worker.time") def test_timeout_handler(self, mock_time): result = [] @HangTheWorkerTask.event_handler(Event.TIMEOUT) def store_task(t, error_msg): self.assertTrue(error_msg) result.append(t) w = Worker(worker_processes=2, wait_interval=0.01, timeout=5) mock_time.time.return_value = 0 task = HangTheWorkerTask(worker_timeout=1) w.add(task) w._run_task(task.task_id) mock_time.time.return_value = 3 w._handle_next_task() self.assertEqual(result, [task]) @mock.patch("luigi.worker.time") def test_timeout_handler_single_worker(self, mock_time): result = [] @HangTheWorkerTask.event_handler(Event.TIMEOUT) def store_task(t, error_msg): self.assertTrue(error_msg) result.append(t) w = Worker(wait_interval=0.01, timeout=5) mock_time.time.return_value = 0 task = HangTheWorkerTask(worker_timeout=1) w.add(task) w._run_task(task.task_id) mock_time.time.return_value = 3 w._handle_next_task() self.assertEqual(result, [task]) class PerTaskRetryPolicyBehaviorTest(LuigiTestCase): def setUp(self): 
super(PerTaskRetryPolicyBehaviorTest, self).setUp() self.per_task_retry_count = 3 self.default_retry_count = 1 self.sch = Scheduler(retry_delay=0.1, retry_count=self.default_retry_count, prune_on_get_work=True) def test_with_all_disabled_with_single_worker(self): """ With this test, a case which has a task (TestWrapperTask), requires two another tasks (TestErrorTask1,TestErrorTask1) which both is failed, is tested. Task TestErrorTask1 has default retry_count which is 1, but Task TestErrorTask2 has retry_count at task level as 2. This test is running on single worker """ class TestErrorTask1(DummyErrorTask): pass e1 = TestErrorTask1() class TestErrorTask2(DummyErrorTask): retry_count = self.per_task_retry_count e2 = TestErrorTask2() class TestWrapperTask(luigi.WrapperTask): def requires(self): return [e2, e1] wt = TestWrapperTask() with Worker(scheduler=self.sch, worker_id="X", keep_alive=True, wait_interval=0.1, wait_jitter=0.05) as w1: self.assertTrue(w1.add(wt)) self.assertFalse(w1.run()) self.assertEqual([wt.task_id], list(self.sch.task_list("PENDING", "UPSTREAM_DISABLED").keys())) self.assertEqual(sorted([e1.task_id, e2.task_id]), sorted(self.sch.task_list("DISABLED", "").keys())) self.assertEqual(0, self.sch._state.get_task(wt.task_id).num_failures()) self.assertEqual(self.per_task_retry_count, self.sch._state.get_task(e2.task_id).num_failures()) self.assertEqual(self.default_retry_count, self.sch._state.get_task(e1.task_id).num_failures()) def test_with_all_disabled_with_multiple_worker(self): """ With this test, a case which has a task (TestWrapperTask), requires two another tasks (TestErrorTask1,TestErrorTask1) which both is failed, is tested. Task TestErrorTask1 has default retry_count which is 1, but Task TestErrorTask2 has retry_count at task level as 2. 
This test is running on multiple worker """ class TestErrorTask1(DummyErrorTask): pass e1 = TestErrorTask1() class TestErrorTask2(DummyErrorTask): retry_count = self.per_task_retry_count e2 = TestErrorTask2() class TestWrapperTask(luigi.WrapperTask): def requires(self): return [e2, e1] wt = TestWrapperTask() with Worker(scheduler=self.sch, worker_id="X", keep_alive=True, wait_interval=0.1, wait_jitter=0.05) as w1: with Worker(scheduler=self.sch, worker_id="Y", keep_alive=True, wait_interval=0.1, wait_jitter=0.05) as w2: with Worker(scheduler=self.sch, worker_id="Z", keep_alive=True, wait_interval=0.1, wait_jitter=0.05) as w3: self.assertTrue(w1.add(wt)) self.assertTrue(w2.add(e2)) self.assertTrue(w3.add(e1)) self.assertFalse(w3.run()) self.assertFalse(w2.run()) self.assertTrue(w1.run()) self.assertEqual([wt.task_id], list(self.sch.task_list("PENDING", "UPSTREAM_DISABLED").keys())) self.assertEqual(sorted([e1.task_id, e2.task_id]), sorted(self.sch.task_list("DISABLED", "").keys())) self.assertEqual(0, self.sch._state.get_task(wt.task_id).num_failures()) self.assertEqual(self.per_task_retry_count, self.sch._state.get_task(e2.task_id).num_failures()) self.assertEqual(self.default_retry_count, self.sch._state.get_task(e1.task_id).num_failures()) def test_with_includes_success_with_single_worker(self): """ With this test, a case which has a task (TestWrapperTask), requires one (TestErrorTask1) FAILED and one (TestSuccessTask1) SUCCESS, is tested. Task TestSuccessTask1 will be DONE successfully, but Task TestErrorTask1 will be failed and it has retry_count at task level as 2. 
This test is running on single worker """ class TestSuccessTask1(DummyTask): pass s1 = TestSuccessTask1() class TestErrorTask1(DummyErrorTask): retry_count = self.per_task_retry_count e1 = TestErrorTask1() class TestWrapperTask(luigi.WrapperTask): def requires(self): return [e1, s1] wt = TestWrapperTask() with Worker(scheduler=self.sch, worker_id="X", keep_alive=True, wait_interval=0.1, wait_jitter=0.05) as w1: self.assertTrue(w1.add(wt)) self.assertFalse(w1.run()) self.assertEqual([wt.task_id], list(self.sch.task_list("PENDING", "UPSTREAM_DISABLED").keys())) self.assertEqual([e1.task_id], list(self.sch.task_list("DISABLED", "").keys())) self.assertEqual([s1.task_id], list(self.sch.task_list("DONE", "").keys())) self.assertEqual(0, self.sch._state.get_task(wt.task_id).num_failures()) self.assertEqual(self.per_task_retry_count, self.sch._state.get_task(e1.task_id).num_failures()) self.assertEqual(0, self.sch._state.get_task(s1.task_id).num_failures()) def test_with_includes_success_with_multiple_worker(self): """ With this test, a case which has a task (TestWrapperTask), requires one (TestErrorTask1) FAILED and one (TestSuccessTask1) SUCCESS, is tested. Task TestSuccessTask1 will be DONE successfully, but Task TestErrorTask1 will be failed and it has retry_count at task level as 2. 
This test is running on multiple worker """ class TestSuccessTask1(DummyTask): pass s1 = TestSuccessTask1() class TestErrorTask1(DummyErrorTask): retry_count = self.per_task_retry_count e1 = TestErrorTask1() class TestWrapperTask(luigi.WrapperTask): def requires(self): return [e1, s1] wt = TestWrapperTask() with Worker(scheduler=self.sch, worker_id="X", keep_alive=True, wait_interval=0.1, wait_jitter=0.05) as w1: with Worker(scheduler=self.sch, worker_id="Y", keep_alive=True, wait_interval=0.1, wait_jitter=0.05) as w2: with Worker(scheduler=self.sch, worker_id="Z", keep_alive=True, wait_interval=0.1, wait_jitter=0.05) as w3: self.assertTrue(w1.add(wt)) self.assertTrue(w2.add(e1)) self.assertTrue(w3.add(s1)) self.assertTrue(w3.run()) self.assertFalse(w2.run()) self.assertTrue(w1.run()) self.assertEqual([wt.task_id], list(self.sch.task_list("PENDING", "UPSTREAM_DISABLED").keys())) self.assertEqual([e1.task_id], list(self.sch.task_list("DISABLED", "").keys())) self.assertEqual([s1.task_id], list(self.sch.task_list("DONE", "").keys())) self.assertEqual(0, self.sch._state.get_task(wt.task_id).num_failures()) self.assertEqual(self.per_task_retry_count, self.sch._state.get_task(e1.task_id).num_failures()) self.assertEqual(0, self.sch._state.get_task(s1.task_id).num_failures()) def test_with_dynamic_dependencies_with_single_worker(self): """ With this test, a case includes dependency tasks(TestErrorTask1,TestErrorTask2) which both are failed. Task TestErrorTask1 has default retry_count which is 1, but Task TestErrorTask2 has retry_count at task level as 2. 
This test is running on single worker """ class TestErrorTask1(DummyErrorTask): pass e1 = TestErrorTask1() class TestErrorTask2(DummyErrorTask): retry_count = self.per_task_retry_count e2 = TestErrorTask2() class TestSuccessTask1(DummyTask): pass s1 = TestSuccessTask1() class TestWrapperTask(DummyTask): def requires(self): return [s1] def run(self): super(TestWrapperTask, self).run() yield e2, e1 wt = TestWrapperTask() with Worker(scheduler=self.sch, worker_id="X", keep_alive=True, wait_interval=0.1, wait_jitter=0.05) as w1: self.assertTrue(w1.add(wt)) self.assertFalse(w1.run()) self.assertEqual([wt.task_id], list(self.sch.task_list("PENDING", "UPSTREAM_DISABLED").keys())) self.assertEqual(sorted([e1.task_id, e2.task_id]), sorted(self.sch.task_list("DISABLED", "").keys())) self.assertEqual(0, self.sch._state.get_task(wt.task_id).num_failures()) self.assertEqual(0, self.sch._state.get_task(s1.task_id).num_failures()) self.assertEqual(self.per_task_retry_count, self.sch._state.get_task(e2.task_id).num_failures()) self.assertEqual(self.default_retry_count, self.sch._state.get_task(e1.task_id).num_failures()) def test_with_dynamic_dependencies_with_multiple_workers(self): """ With this test, a case includes dependency tasks(TestErrorTask1,TestErrorTask2) which both are failed. Task TestErrorTask1 has default retry_count which is 1, but Task TestErrorTask2 has retry_count at task level as 2. 
This test is running on multiple worker """ class TestErrorTask1(DummyErrorTask): pass e1 = TestErrorTask1() class TestErrorTask2(DummyErrorTask): retry_count = self.per_task_retry_count e2 = TestErrorTask2() class TestSuccessTask1(DummyTask): pass s1 = TestSuccessTask1() class TestWrapperTask(DummyTask): def requires(self): return [s1] def run(self): super(TestWrapperTask, self).run() yield e2, e1 wt = TestWrapperTask() with Worker(scheduler=self.sch, worker_id="X", keep_alive=True, wait_interval=0.1, wait_jitter=0.05) as w1: with Worker(scheduler=self.sch, worker_id="Y", keep_alive=True, wait_interval=0.1, wait_jitter=0.05) as w2: self.assertTrue(w1.add(wt)) self.assertTrue(w2.add(s1)) self.assertTrue(w2.run()) self.assertFalse(w1.run()) self.assertEqual([wt.task_id], list(self.sch.task_list("PENDING", "UPSTREAM_DISABLED").keys())) self.assertEqual(sorted([e1.task_id, e2.task_id]), sorted(self.sch.task_list("DISABLED", "").keys())) self.assertEqual(0, self.sch._state.get_task(wt.task_id).num_failures()) self.assertEqual(0, self.sch._state.get_task(s1.task_id).num_failures()) self.assertEqual(self.per_task_retry_count, self.sch._state.get_task(e2.task_id).num_failures()) self.assertEqual(self.default_retry_count, self.sch._state.get_task(e1.task_id).num_failures()) def test_per_task_disable_persist_with_single_worker(self): """ Ensure that `Task.disable_window` impacts the task retrying policy: - with the scheduler retry policy (disable_window=3), task fails twice and gets disabled - with the task retry policy (disable_window=0.5) task never gets into the DISABLED state """ class TwoErrorsThenSuccessTask(Task): """ The task is failing two times and then succeeds, waiting 1s before each try """ retry_index = 0 disable_window = None def run(self): time.sleep(1) self.retry_index += 1 if self.retry_index < 3: raise Exception("Retry index is %s for %s" % (self.retry_index, self.task_family)) t = TwoErrorsThenSuccessTask() sch = Scheduler(retry_delay=0.1, retry_count=2, 
prune_on_get_work=True, disable_window=2) with Worker(scheduler=sch, worker_id="X", keep_alive=True, wait_interval=0.1, wait_jitter=0.05) as w: self.assertTrue(w.add(t)) self.assertFalse(w.run()) self.assertEqual(2, t.retry_index) self.assertEqual([t.task_id], list(sch.task_list("DISABLED").keys())) self.assertEqual(2, sch._state.get_task(t.task_id).num_failures()) t = TwoErrorsThenSuccessTask() t.retry_index = 0 t.disable_window = 0.5 sch = Scheduler(retry_delay=0.1, retry_count=2, prune_on_get_work=True, disable_window=2) with Worker(scheduler=sch, worker_id="X", keep_alive=True, wait_interval=0.1, wait_jitter=0.05) as w: self.assertTrue(w.add(t)) # Worker.run return False even if a task failed first but eventually succeeded. self.assertFalse(w.run()) self.assertEqual(3, t.retry_index) self.assertEqual([t.task_id], list(sch.task_list("DONE").keys())) self.assertEqual(1, len(sch._state.get_task(t.task_id).failures)) ================================================ FILE: test/wrap_test.py ================================================ # -*- coding: utf-8 -*- # # Copyright 2012-2015 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# import datetime from helpers import unittest import luigi import luigi.notifications from luigi.mock import MockTarget from luigi.util import inherits luigi.notifications.DEBUG = True class A(luigi.Task): task_namespace = "wrap" # to prevent task name conflict between tests def output(self): return MockTarget("/tmp/a.txt") def run(self): f = self.output().open("w") print("hello, world", file=f) f.close() class B(luigi.Task): date = luigi.DateParameter() def output(self): return MockTarget(self.date.strftime("/tmp/b-%Y-%m-%d.txt")) def run(self): f = self.output().open("w") print("goodbye, space", file=f) f.close() def XMLWrapper(cls): @inherits(cls) class XMLWrapperCls(luigi.Task): def requires(self): return self.clone_parent() def run(self): f = self.input().open("r") g = self.output().open("w") print('<?xml version="1.0" ?>', file=g) for line in f: print("<dummy-xml>" + line.strip() + "</dummy-xml>", file=g) g.close() return XMLWrapperCls class AXML(XMLWrapper(A)): def output(self): return MockTarget("/tmp/a.xml") class BXML(XMLWrapper(B)): def output(self): return MockTarget(self.date.strftime("/tmp/b-%Y-%m-%d.xml")) class WrapperTest(unittest.TestCase): """This test illustrates how a task class can wrap another task class by modifying its behavior. 
See instance_wrap_test.py for an example of how instances can wrap each other.""" workers = 1 def setUp(self): MockTarget.fs.clear() def test_a(self): luigi.build([AXML()], local_scheduler=True, no_lock=True, workers=self.workers) self.assertEqual(MockTarget.fs.get_data("/tmp/a.xml"), b'<?xml version="1.0" ?>\n<dummy-xml>hello, world</dummy-xml>\n') def test_b(self): luigi.build([BXML(datetime.date(2012, 1, 1))], local_scheduler=True, no_lock=True, workers=self.workers) self.assertEqual(MockTarget.fs.get_data("/tmp/b-2012-01-01.xml"), b'<?xml version="1.0" ?>\n<dummy-xml>goodbye, space</dummy-xml>\n') class WrapperWithMultipleWorkersTest(WrapperTest): workers = 7 ================================================ FILE: tox.ini ================================================ [tox] requires = tox>=4.22 # `dependency_groups` needed tox-uv>=1.19 envlist = py{310,311,312,313}-{cdh,hdp,core,contrib,apache,aws,gcloud,mysql,postgres,unixsocket,azureblob,dropbox}, visualiser, docs, lint, typecheck skipsdist = True [pytest] addopts = --cov=luigi --cov-report=xml -vv --strict-markers --ignore-glob="**/_*" --fulltrace testpaths = test markers = contrib: tests related to luigi/contrib apache: tests related to apache aws: tests related to AWS postgres: tests related to postgresql mysql: tests related to mysql scheduler: tests related to scheduler cdh: tests related to cdh hdp: tests related to hdp gcloud: tests related to GCP unixsocket: tests related to unixsocket dropbox: tests related to dropbox azureblob: tests related to azure unmarked: tests with no explicit markers [testenv] runner = uv-venv-lock-runner allowlist_externals = {toxinidir}/scripts/ci/*.sh dependency_groups = core: common contrib: common apache: common aws: common postgres: test_postgres mysql: common scheduler: common cdh: test_cdh hdp: test_hdp gcloud: test_gcloud unixsocket: test_unixsocket dropbox: test_dropbox azureblob: common passenv = USER, JAVA_HOME, POSTGRES_USER, DATAPROC_TEST_PROJECT_ID, 
GCS_TEST_PROJECT_ID, GCS_TEST_BUCKET, GOOGLE_APPLICATION_CREDENTIALS, CI, DROPBOX_APP_TOKEN, DOCKERHUB_TOKEN, GITHUB_ACTIONS, OVERRIDE_SKIP_CI_TESTS setenv = LC_ALL = en_US.utf-8 cdh: HADOOP_DISTRO=cdh cdh: HADOOP_HOME={toxinidir}/.tox/hadoop-cdh hdp: HADOOP_DISTRO=hdp hdp: HADOOP_HOME={toxinidir}/.tox/hadoop-hdp LUIGI_CONFIG_PATH={toxinidir}/test/testconfig/luigi.cfg COVERAGE_PROCESS_START={toxinidir}/.coveragerc FULL_COVERAGE=true AWS_DEFAULT_REGION=us-east-1 AWS_ACCESS_KEY_ID=accesskey AWS_SECRET_ACCESS_KEY=secretkey AZURITE_ACCOUNT_NAME=devstoreaccount1 AZURITE_ACCOUNT_KEY=YXp1cml0ZQ== AZURITE_CUSTOM_DOMAIN=localhost:10000 commands = # Setup cdh,hdp: {toxinidir}/scripts/ci/setup_hadoop_env.sh azureblob: {toxinidir}/scripts/ci/install_start_azurite.sh {toxinidir}/scripts/ci {envpython} --version # Test contrib: {envpython} test/runtests.py test/contrib/ -m "contrib or unmarked" {posargs:} apache: {envpython} test/runtests.py -m apache {posargs:} aws: {envpython} test/runtests.py -m aws {posargs:} mysql: {envpython} test/runtests.py -m mysql {posargs:} postgres: {envpython} test/runtests.py -m postgres {posargs:} scheduler: {envpython} test/runtests.py -m scheduler {posargs:} cdh,hdp: {envpython} test/runtests.py -m minicluster {posargs:} gcloud: {envpython} test/runtests.py -m gcloud {posargs:} unixsocket: {envpython} test/runtests.py -m unixsocket {posargs:} dropbox: {envpython} test/runtests.py -m dropbox {posargs:} azureblob: {envpython} test/runtests.py -m azureblob {posargs:} core: {envpython} test/runtests.py --doctest-modules -m "not minicluster and not gcloud and not mysql and not postgres and not unixsocket and not contrib and not apache and not aws and not azureblob and not dropbox" -n auto --dist=loadfile {posargs:} # Teardown azureblob: {toxinidir}/scripts/ci/stop_azurite.sh [testenv:visualiser] runner = uv-venv-lock-runner dependency_groups = visualizer passenv = {[testenv]passenv} setenv = LC_ALL = en_US.utf-8 
LUIGI_CONFIG_PATH={toxinidir}/test/testconfig/luigi.cfg TEST_VISUALISER=1 commands = python --version pytest test/visualiser [testenv:lint] dependency_groups = lint commands = ruff check . ruff format --check . [testenv:typecheck] dependency_groups = common commands = mypy luigi/ [testenv:docs] # Python 3.13 required for Sphinx 9.x basepython = py313 # Build documentation using sphinx. # Call this using `tox run -e docs`. dependency_groups = docs setenv = AWS_DEFAULT_REGION=us-east-1 commands = sphinx-build -W -b html -d {envtmpdir}/doctrees doc doc/_build/html