Repository: nils-braun/dask-sql Branch: main Commit: 775b56fb8f99 Files: 251 Total size: 1.3 MB Directory structure: gitextract_q18bbzy0/ ├── .cargo/ │ └── config.toml ├── .coveragerc ├── .dockerignore ├── .github/ │ ├── CODEOWNERS │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ ├── documentation-request.md │ │ ├── feature_request.md │ │ └── submit-question.md │ ├── dependabot.yml │ └── workflows/ │ ├── conda.yml │ ├── docker.yml │ ├── release.yml │ ├── rust.yml │ ├── style.yml │ ├── test-upstream.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Cargo.toml ├── LICENSE.txt ├── MANIFEST.in ├── README.md ├── conftest.py ├── continuous_integration/ │ ├── docker/ │ │ ├── cloud.dockerfile │ │ ├── conda.txt │ │ └── main.dockerfile │ ├── environment-3.10.yaml │ ├── environment-3.11.yaml │ ├── environment-3.12.yaml │ ├── environment-3.9.yaml │ ├── gpuci/ │ │ ├── environment-3.10.yaml │ │ ├── environment-3.11.yaml │ │ └── environment-3.9.yaml │ ├── recipe/ │ │ ├── build.sh │ │ ├── conda_build_config.yaml │ │ ├── meta.yaml │ │ └── run_test.py │ └── scripts/ │ ├── startup_script.py │ └── update-dependencies.sh ├── dask_sql/ │ ├── __init__.py │ ├── _compat.py │ ├── cmd.py │ ├── config.py │ ├── context.py │ ├── datacontainer.py │ ├── input_utils/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── convert.py │ │ ├── dask.py │ │ ├── hive.py │ │ ├── intake.py │ │ ├── location.py │ │ ├── pandaslike.py │ │ └── sqlalchemy.py │ ├── integrations/ │ │ ├── __init__.py │ │ ├── fugue.py │ │ └── ipython.py │ ├── mappings.py │ ├── physical/ │ │ ├── __init__.py │ │ ├── rel/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── convert.py │ │ │ ├── custom/ │ │ │ │ ├── __init__.py │ │ │ │ ├── alter.py │ │ │ │ ├── analyze_table.py │ │ │ │ ├── create_catalog_schema.py │ │ │ │ ├── create_experiment.py │ │ │ │ ├── create_memory_table.py │ │ │ │ ├── create_model.py │ │ │ │ ├── create_table.py │ │ │ │ ├── describe_model.py │ │ │ │ 
├── distributeby.py │ │ │ │ ├── drop_model.py │ │ │ │ ├── drop_schema.py │ │ │ │ ├── drop_table.py │ │ │ │ ├── export_model.py │ │ │ │ ├── metrics.py │ │ │ │ ├── predict_model.py │ │ │ │ ├── show_columns.py │ │ │ │ ├── show_models.py │ │ │ │ ├── show_schemas.py │ │ │ │ ├── show_tables.py │ │ │ │ ├── use_schema.py │ │ │ │ └── wrappers.py │ │ │ └── logical/ │ │ │ ├── __init__.py │ │ │ ├── aggregate.py │ │ │ ├── cross_join.py │ │ │ ├── empty.py │ │ │ ├── explain.py │ │ │ ├── filter.py │ │ │ ├── join.py │ │ │ ├── limit.py │ │ │ ├── project.py │ │ │ ├── sample.py │ │ │ ├── sort.py │ │ │ ├── subquery_alias.py │ │ │ ├── table_scan.py │ │ │ ├── union.py │ │ │ ├── values.py │ │ │ └── window.py │ │ ├── rex/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── convert.py │ │ │ └── core/ │ │ │ ├── __init__.py │ │ │ ├── alias.py │ │ │ ├── call.py │ │ │ ├── input_ref.py │ │ │ ├── literal.py │ │ │ └── subquery.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── filter.py │ │ ├── groupby.py │ │ ├── ml_classes.py │ │ ├── sort.py │ │ └── statistics.py │ ├── server/ │ │ ├── __init__.py │ │ ├── app.py │ │ ├── presto_jdbc.py │ │ └── responses.py │ ├── sql-schema.yaml │ ├── sql.yaml │ └── utils.py ├── docs/ │ ├── Makefile │ ├── environment.yml │ ├── make.bat │ ├── requirements-docs.txt │ └── source/ │ ├── api.rst │ ├── best_practices.rst │ ├── cmd.rst │ ├── conf.py │ ├── configuration.rst │ ├── custom.rst │ ├── data_input.rst │ ├── fugue.rst │ ├── how_does_it_work.rst │ ├── index.rst │ ├── installation.rst │ ├── machine_learning.rst │ ├── quickstart.rst │ ├── server.rst │ ├── sql/ │ │ ├── creation.rst │ │ ├── describe.rst │ │ ├── ml.rst │ │ └── select.rst │ └── sql.rst ├── notebooks/ │ ├── Custom Functions.ipynb │ ├── Feature Overview.ipynb │ ├── FugueSQL.ipynb │ └── iris.csv ├── pyproject.toml ├── rustfmt.toml ├── setup.cfg ├── src/ │ ├── dialect.rs │ ├── error.rs │ ├── expression.rs │ ├── lib.rs │ ├── parser.rs │ ├── sql/ │ │ ├── column.rs │ │ ├── exceptions.rs │ │ ├── function.rs │ │ ├── 
logical/ │ │ │ ├── aggregate.rs │ │ │ ├── alter_schema.rs │ │ │ ├── alter_table.rs │ │ │ ├── analyze_table.rs │ │ │ ├── create_catalog_schema.rs │ │ │ ├── create_experiment.rs │ │ │ ├── create_memory_table.rs │ │ │ ├── create_model.rs │ │ │ ├── create_table.rs │ │ │ ├── describe_model.rs │ │ │ ├── drop_model.rs │ │ │ ├── drop_schema.rs │ │ │ ├── drop_table.rs │ │ │ ├── empty_relation.rs │ │ │ ├── explain.rs │ │ │ ├── export_model.rs │ │ │ ├── filter.rs │ │ │ ├── join.rs │ │ │ ├── limit.rs │ │ │ ├── predict_model.rs │ │ │ ├── projection.rs │ │ │ ├── repartition_by.rs │ │ │ ├── show_columns.rs │ │ │ ├── show_models.rs │ │ │ ├── show_schemas.rs │ │ │ ├── show_tables.rs │ │ │ ├── sort.rs │ │ │ ├── subquery_alias.rs │ │ │ ├── table_scan.rs │ │ │ ├── use_schema.rs │ │ │ └── window.rs │ │ ├── logical.rs │ │ ├── optimizer/ │ │ │ ├── decorrelate_where_exists.rs │ │ │ ├── decorrelate_where_in.rs │ │ │ ├── dynamic_partition_pruning.rs │ │ │ ├── join_reorder.rs │ │ │ └── utils.rs │ │ ├── optimizer.rs │ │ ├── parser_utils.rs │ │ ├── preoptimizer.rs │ │ ├── schema.rs │ │ ├── statement.rs │ │ ├── table.rs │ │ ├── types/ │ │ │ ├── rel_data_type.rs │ │ │ └── rel_data_type_field.rs │ │ └── types.rs │ └── sql.rs └── tests/ ├── __init__.py ├── integration/ │ ├── __init__.py │ ├── fixtures.py │ ├── test_analyze.py │ ├── test_cmd.py │ ├── test_compatibility.py │ ├── test_complex.py │ ├── test_create.py │ ├── test_distributeby.py │ ├── test_explain.py │ ├── test_filter.py │ ├── test_fugue.py │ ├── test_function.py │ ├── test_groupby.py │ ├── test_hive.py │ ├── test_intake.py │ ├── test_jdbc.py │ ├── test_join.py │ ├── test_model.py │ ├── test_over.py │ ├── test_postgres.py │ ├── test_rex.py │ ├── test_sample.py │ ├── test_schema.py │ ├── test_select.py │ ├── test_server.py │ ├── test_show.py │ ├── test_sort.py │ ├── test_sqlite.py │ └── test_union.py ├── unit/ │ ├── __init__.py │ ├── test_call.py │ ├── test_config.py │ ├── test_context.py │ ├── test_datacontainer.py │ ├── test_mapping.py 
│ ├── test_ml_utils.py │ ├── test_queries.py │ ├── test_statistics.py │ └── test_utils.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .cargo/config.toml ================================================ [target.x86_64-apple-darwin] rustflags = [ "-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup", ] [target.aarch64-apple-darwin] rustflags = [ "-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup", ] ================================================ FILE: .coveragerc ================================================ [run] omit = tests/* branch = True [report] # Regexes for lines to exclude from consideration exclude_lines = # Have to re-enable the standard pragma pragma: no cover # Don't complain about missing debug-only code: def __repr__ # Don't complain if tests don't hit defensive assertion code: raise AssertionError raise NotImplementedError # Don't complain if non-runnable code isn't run: if __name__ == .__main__.: ================================================ FILE: .dockerignore ================================================ node_modules .next ================================================ FILE: .github/CODEOWNERS ================================================ # global codeowners * @ayushdg @charlesbluca @galipremsagar # rust codeowners .cargo/ @ayushdg @charlesbluca @galipremsagar @jdye64 src/ @ayushdg @charlesbluca @galipremsagar @jdye64 Cargo.toml @ayushdg @charlesbluca @galipremsagar @jdye64 Cargo.lock @ayushdg @charlesbluca @galipremsagar @jdye64 ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a bug report to help us improve dask-sql title: "[BUG]" labels: "bug, needs triage" assignees: '' --- **What happened**: **What you expected to happen**: **Minimal Complete Verifiable 
Example**: ```python # Put your MCVE code here ``` **Anything else we need to know?**: **Environment**: - dask-sql version: - Python version: - Operating System: - Install method (conda, pip, source): ================================================ FILE: .github/ISSUE_TEMPLATE/documentation-request.md ================================================ --- name: Documentation request about: Report incorrect or needed documentation title: "[DOC]" labels: "documentation" assignees: '' --- ## Report incorrect documentation **Location of incorrect documentation** Provide links and line numbers if applicable. **Describe the problems or issues found in the documentation** A clear and concise description of what you found to be incorrect. **Steps taken to verify documentation is incorrect** List any steps you have taken: **Suggested fix for documentation** Detail proposed changes to fix the documentation if you have any. --- ## Report needed documentation **Report needed documentation** A clear and concise description of what documentation you believe is needed and why. **Describe the documentation you'd like** A clear and concise description of what you want to happen. **Steps taken to search for needed documentation** List any steps you have taken: ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for dask-sql title: "[ENH]" labels: "enhancement, needs triage" assignees: '' --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I wish I could use dask-sql to do [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered.
**Additional context** Add any other context, code examples, or references to existing implementations about the feature request here. ================================================ FILE: .github/ISSUE_TEMPLATE/submit-question.md ================================================ --- name: Submit question about: Ask a general question about dask-sql title: "[QST]" labels: "question" assignees: '' --- **What is your question?** ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: - package-ecosystem: "cargo" directory: "/" schedule: interval: "daily" ignore: # arrow and datafusion are bumped manually - dependency-name: "arrow" update-types: ["version-update:semver-major"] - dependency-name: "datafusion" update-types: ["version-update:semver-major"] - dependency-name: "datafusion-*" update-types: ["version-update:semver-major"] - package-ecosystem: "github-actions" directory: "/" schedule: # Check for updates to GitHub Actions every week interval: "weekly" ignore: # prefer updating cibuildwheel manually as needed - dependency-name: "pypa/cibuildwheel" ================================================ FILE: .github/workflows/conda.yml ================================================ name: Build conda nightly on: push: branches: - main pull_request: paths: - Cargo.toml - Cargo.lock - pyproject.toml - continuous_integration/recipe/** - .github/workflows/conda.yml schedule: - cron: '0 0 * * 0' # When this workflow is queued, automatically cancel any previous running # or pending jobs from the same branch concurrency: group: conda-${{ github.head_ref }} cancel-in-progress: true # Required shell entrypoint to have properly activated conda environments defaults: run: shell: bash -l {0} jobs: conda: name: "Build conda nightlies (python: ${{ matrix.python }}, arch: ${{ matrix.arch }})" runs-on: ubuntu-latest strategy: fail-fast: false matrix: python: ["3.9", "3.10", "3.11", "3.12"]
arch: ["linux-64", "linux-aarch64"] steps: - name: Manage disk space if: matrix.arch == 'linux-aarch64' run: | sudo mkdir -p /opt/empty_dir || true for d in \ /opt/ghc \ /opt/hostedtoolcache \ /usr/lib/jvm \ /usr/local/.ghcup \ /usr/local/lib/android \ /usr/local/share/powershell \ /usr/share/dotnet \ /usr/share/swift \ ; do sudo rsync --stats -a --delete /opt/empty_dir/ $d || true done sudo apt-get purge -y -f firefox \ google-chrome-stable \ microsoft-edge-stable sudo apt-get autoremove -y >& /dev/null sudo apt-get autoclean -y >& /dev/null sudo docker image prune --all --force df -h - name: Create swapfile if: matrix.arch == 'linux-aarch64' run: | sudo fallocate -l 10GiB /swapfile || true sudo chmod 600 /swapfile || true sudo mkswap /swapfile || true sudo swapon /swapfile || true - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Set up Python uses: conda-incubator/setup-miniconda@v2.3.0 with: miniforge-variant: Mambaforge use-mamba: true python-version: "3.9" channel-priority: strict - name: Install dependencies run: | mamba install -c conda-forge "boa<0.17" "conda-build<24.1" conda-verify which python pip list mamba list - name: Build conda packages run: | # suffix for nightly package versions export VERSION_SUFFIX=a`date +%y%m%d` conda mambabuild continuous_integration/recipe \ --python ${{ matrix.python }} \ --variants "{target_platform: [${{ matrix.arch }}]}" \ --error-overlinking \ --no-test \ --no-anaconda-upload \ --output-folder packages - name: Test conda packages if: matrix.arch == 'linux-64' # can only test native platform packages run: | conda mambabuild --test packages/${{ matrix.arch }}/*.tar.bz2 - name: Upload conda packages as artifacts uses: actions/upload-artifact@v3 with: name: "conda nightlies (python - ${{ matrix.python }}, arch - ${{ matrix.arch }})" # need to install all conda channel metadata to properly install locally path: packages/ - name: Upload conda packages to Anaconda if: | github.event_name == 'push' && github.repository 
== 'dask-contrib/dask-sql' env: ANACONDA_API_TOKEN: ${{ secrets.DASK_CONDA_TOKEN }} run: | # install anaconda for upload mamba install -c conda-forge anaconda-client anaconda upload --label dev packages/${{ matrix.arch }}/*.tar.bz2 ================================================ FILE: .github/workflows/docker.yml ================================================ name: Build Docker image on: release: types: [created] push: branches: - main pull_request: paths: - Cargo.toml - Cargo.lock - pyproject.toml - continuous_integration/docker/** - .github/workflows/docker.yml # When this workflow is queued, automatically cancel any previous running # or pending jobs from the same branch concurrency: group: docker-${{ github.ref }} cancel-in-progress: true jobs: push_to_registry: name: Push Docker image to Docker Hub runs-on: ubuntu-latest env: DOCKER_PUSH: ${{ contains(fromJSON('["push", "release"]'), github.event_name) && github.repository == 'dask-contrib/dask-sql' }} strategy: fail-fast: false matrix: platform: ["linux/amd64", "linux/arm64", "linux/386"] steps: - uses: actions/checkout@v4 - name: Login to DockerHub if: ${{ fromJSON(env.DOCKER_PUSH) }} uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - name: Docker meta for main image id: docker_meta_main uses: crazy-max/ghaction-docker-meta@v5 with: images: nbraun/dask-sql - name: Build and push main image uses: docker/build-push-action@v5 with: context: . 
file: ./continuous_integration/docker/main.dockerfile build-args: DOCKER_META_VERSION=${{ steps.docker_meta_main.outputs.version }} platforms: ${{ matrix.platform }} tags: ${{ steps.docker_meta_main.outputs.tags }} labels: ${{ steps.docker_meta_main.outputs.labels }} push: ${{ fromJSON(env.DOCKER_PUSH) }} load: ${{ !fromJSON(env.DOCKER_PUSH) }} - name: Check images run: | df -h docker image ls docker image inspect ${{ steps.docker_meta_main.outputs.tags }} - name: Docker meta for cloud image id: docker_meta_cloud uses: crazy-max/ghaction-docker-meta@v5 with: images: nbraun/dask-sql-cloud - name: Build and push cloud image uses: docker/build-push-action@v5 with: context: . file: ./continuous_integration/docker/cloud.dockerfile build-args: DOCKER_META_VERSION=${{ steps.docker_meta_main.outputs.version }} platforms: ${{ matrix.platform }} tags: ${{ steps.docker_meta_cloud.outputs.tags }} labels: ${{ steps.docker_meta_cloud.outputs.labels }} push: ${{ fromJSON(env.DOCKER_PUSH) }} load: ${{ !fromJSON(env.DOCKER_PUSH) }} ================================================ FILE: .github/workflows/release.yml ================================================ name: Upload Python package on: release: types: [created] pull_request: paths: - .github/workflows/release.yml - dask_sql/__init__.py # When this workflow is queued, automatically cancel any previous running # or pending jobs from the same branch concurrency: group: release-${{ github.head_ref }} cancel-in-progress: true env: upload: ${{ github.event_name == 'release' && github.repository == 'dask-contrib/dask-sql' }} jobs: linux: name: Build and publish wheels for linux ${{ matrix.target }} runs-on: ubuntu-latest strategy: fail-fast: false matrix: target: [x86_64, aarch64] steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v4 with: python-version: '3.10' - name: Build wheels for x86_64 if: matrix.target == 'x86_64' uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} args: --release --out dist 
sccache: 'true' manylinux: '2_17' - name: Build wheels for aarch64 if: matrix.target == 'aarch64' uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} args: --release --out dist --zig sccache: 'true' manylinux: '2_17' - name: Check dist files run: | pip install twine twine check dist/* ls -lh dist/ - name: Upload binary wheels uses: actions/upload-artifact@v3 with: name: wheels for linux ${{ matrix.target }} path: dist/* - name: Publish package if: env.upload == 'true' env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: twine upload dist/* windows: name: Build and publish wheels for windows runs-on: windows-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v4 with: python-version: '3.10' architecture: x64 - name: Build wheels uses: PyO3/maturin-action@v1 with: target: x64 args: --release --out dist sccache: 'true' - name: Check dist files run: | pip install twine twine check dist/* ls dist/ - name: Upload binary wheels uses: actions/upload-artifact@v3 with: name: wheels for windows path: dist/* - name: Publish package if: env.upload == 'true' env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: twine upload dist/* macos: name: Build and publish wheels for macos ${{ matrix.target }} runs-on: macos-latest strategy: fail-fast: false matrix: target: [x86_64, aarch64] steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v4 with: python-version: '3.10' - name: Build wheels uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} args: --release --out dist sccache: 'true' - name: Check dist files run: | pip install twine twine check dist/* ls -lh dist/ - name: Upload binary wheels uses: actions/upload-artifact@v3 with: name: wheels for macos ${{ matrix.target }} path: dist/* - name: Publish package if: env.upload == 'true' env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: twine 
upload dist/* sdist: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Build sdist uses: PyO3/maturin-action@v1 with: command: sdist args: --out dist - uses: actions/setup-python@v4 with: python-version: '3.10' - name: Check dist files run: | pip install twine twine check dist/* ls -lh dist/ - name: Publish source distribution if: env.upload == 'true' env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: twine upload dist/* ================================================ FILE: .github/workflows/rust.yml ================================================ name: Test Rust package on: # always trigger on PR push: branches: - main pull_request: # manual trigger # https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow workflow_dispatch: env: # Disable full debug symbol generation to speed up CI build and keep memory down # "1" means line tables only, which is useful for panic tracebacks. RUSTFLAGS: "-C debuginfo=1" jobs: detect-ci-trigger: name: Check for upstream trigger phrase runs-on: ubuntu-latest if: github.repository == 'dask-contrib/dask-sql' outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - uses: actions/checkout@v4 with: fetch-depth: 2 - uses: xarray-contrib/ci-trigger@v1.2 id: detect-trigger with: keyword: "[test-df-upstream]" # Check crate compiles linux-build-lib: name: cargo check needs: [detect-ci-trigger] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: actions-rs/toolchain@v1 with: toolchain: 1.72 default: true - name: Cache Cargo uses: actions/cache@v3 with: path: /home/runner/.cargo key: cargo-cache - name: Optionally update upstream dependencies if: needs.detect-ci-trigger.outputs.triggered == 'true' run: | bash continuous_integration/scripts/update-dependencies.sh - name: Check workspace in debug mode run: | cargo check - name: Check workspace in release mode run: | cargo check --release # test the crate linux-test: 
name: cargo test needs: [detect-ci-trigger] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: submodules: true - uses: actions-rs/toolchain@v1 with: toolchain: 1.72 default: true - name: Cache Cargo uses: actions/cache@v3 with: path: /home/runner/.cargo key: cargo-cache - name: Optionally update upstream dependencies if: needs.detect-ci-trigger.outputs.triggered == 'true' run: | bash continuous_integration/scripts/update-dependencies.sh - name: Run tests run: | cargo test ================================================ FILE: .github/workflows/style.yml ================================================ --- name: Python style check on: [pull_request] # When this workflow is queued, automatically cancel any previous running # or pending jobs from the same branch concurrency: group: style-${{ github.head_ref }} cancel-in-progress: true jobs: pre-commit: name: Run pre-commit hooks runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v4 - uses: actions-rs/toolchain@v1 with: toolchain: 1.72 components: clippy default: true - uses: actions-rs/toolchain@v1 with: toolchain: nightly components: rustfmt - uses: pre-commit/action@v3.0.0 ================================================ FILE: .github/workflows/test-upstream.yml ================================================ name: Nightly upstream testing on: schedule: - cron: "0 0 * * *" # Daily “At 00:00” UTC workflow_dispatch: # allows you to trigger the workflow run manually # Required shell entrypoint to have properly activated conda environments defaults: run: shell: bash -l {0} jobs: test-dev: name: "Test upstream dev (${{ matrix.os }}, python: ${{ matrix.python }}, distributed: ${{ matrix.distributed }}, query-planning: ${{ matrix.query-planning }})" runs-on: ${{ matrix.os }} env: CONDA_FILE: continuous_integration/environment-${{ matrix.python }}.yaml DASK_SQL_DISTRIBUTED_TESTS: ${{ matrix.distributed }} DASK_DATAFRAME__QUERY_PLANNING: ${{ matrix.query-planning }} 
strategy: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] python: ["3.9", "3.10", "3.11", "3.12"] distributed: [false] query-planning: [true] include: # run tests on a distributed client - os: "ubuntu-latest" python: "3.9" distributed: true query-planning: true - os: "ubuntu-latest" python: "3.11" distributed: true query-planning: true # run tests with query planning disabled - os: "ubuntu-latest" python: "3.9" distributed: false query-planning: false - os: "ubuntu-latest" python: "3.11" distributed: false query-planning: false steps: - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. - name: Set up Python uses: conda-incubator/setup-miniconda@v2.3.0 with: miniforge-variant: Mambaforge use-mamba: true python-version: ${{ matrix.python }} channel-priority: strict activate-environment: dask-sql environment-file: ${{ env.CONDA_FILE }} - uses: actions-rs/toolchain@v1 with: toolchain: 1.72 default: true - name: Install x86_64-apple-darwin target if: matrix.os == 'macos-latest' run: rustup target add x86_64-apple-darwin - name: Build the Rust DataFusion bindings run: | maturin develop - name: Install hive testing dependencies if: matrix.os == 'ubuntu-latest' run: | docker pull bde2020/hive:2.3.2-postgresql-metastore docker pull bde2020/hive-metastore-postgresql:2.3.0 - name: Install upstream dev Dask run: | mamba install --no-channel-priority dask/label/dev::dask - name: Install pytest-reportlog run: | # TODO: add pytest-reportlog to testing environments if we move over to JSONL output mamba install pytest-reportlog - name: Test with pytest id: run_tests run: | pytest --report-log test-${{ matrix.os }}-py${{ matrix.python }}-results.jsonl --cov-report=xml -n auto tests --dist loadfile - name: Upload pytest results for failure if: | always() && steps.run_tests.outcome != 'skipped' uses: actions/upload-artifact@v3 with: name: test-${{ matrix.os }}-py${{ matrix.python }}-results path: test-${{ matrix.os 
}}-py${{ matrix.python }}-results.jsonl import-dev: name: "Test importing with bare requirements and upstream dev (query-planning: ${{ matrix.query-planning }})" runs-on: ubuntu-latest strategy: fail-fast: false matrix: query-planning: [true, false] steps: - uses: actions/checkout@v4 - name: Set up Python uses: conda-incubator/setup-miniconda@v2.3.0 with: miniforge-variant: Mambaforge use-mamba: true python-version: "3.9" channel-priority: strict - uses: actions-rs/toolchain@v1 with: toolchain: 1.72 default: true - name: Install dependencies and nothing else run: | pip install -e . -vv which python pip list mamba list - name: Install upstream dev Dask run: | python -m pip install git+https://github.com/dask/dask python -m pip install git+https://github.com/dask/dask-expr python -m pip install git+https://github.com/dask/distributed - name: Try to import dask-sql env: DASK_DATAFRAME_QUERY_PLANNING: ${{ matrix.query-planning }} run: | python -c "import dask_sql; print('ok')" report-failures: name: Open issue for upstream dev failures needs: [test-dev, import-dev] if: | always() && ( needs.test-dev.result == 'failure' || needs.import-dev.result == 'failure' ) && github.repository == 'dask-contrib/dask-sql' runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: actions/download-artifact@v3 - name: Prepare logs & issue label run: | # TODO: remove this if xarray-contrib/issue-from-pytest-log no longer needs a log-path if [ -f test-ubuntu-latest-py3.10-results/test-ubuntu-latest-py3.10-results.jsonl ]; then cp test-ubuntu-latest-py3.10-results/test-ubuntu-latest-py3.10-results.jsonl results.jsonl else touch results.jsonl fi - name: Open or update issue on failure uses: xarray-contrib/issue-from-pytest-log@v1.2.6 with: log-path: results.jsonl issue-title: ⚠️ Upstream CI failed ⚠️ issue-label: upstream ================================================ FILE: .github/workflows/test.yml ================================================ name: Test Python package on: 
push: branches: - main pull_request: # When this workflow is queued, automatically cancel any previous running # or pending jobs from the same branch concurrency: group: test-${{ github.head_ref }} cancel-in-progress: true # Required shell entrypoint to have properly activated conda environments defaults: run: shell: bash -l {0} jobs: detect-ci-trigger: name: Check for upstream trigger phrase runs-on: ubuntu-latest if: github.repository == 'dask-contrib/dask-sql' outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - uses: actions/checkout@v4 with: fetch-depth: 2 - uses: xarray-contrib/ci-trigger@v1.2 id: detect-trigger with: keyword: "[test-upstream]" test: name: "Build & Test (${{ matrix.os }}, python: ${{ matrix.python }}, distributed: ${{ matrix.distributed }}, query-planning: ${{ matrix.query-planning }})" needs: [detect-ci-trigger] runs-on: ${{ matrix.os }} env: CONDA_FILE: continuous_integration/environment-${{ matrix.python }}.yaml DASK_SQL_DISTRIBUTED_TESTS: ${{ matrix.distributed }} DASK_DATAFRAME__QUERY_PLANNING: ${{ matrix.query-planning }} strategy: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] python: ["3.9", "3.10", "3.11", "3.12"] distributed: [false] query-planning: [true] include: # run tests on a distributed client - os: "ubuntu-latest" python: "3.9" distributed: true query-planning: true - os: "ubuntu-latest" python: "3.11" distributed: true query-planning: true # run tests with query planning disabled - os: "ubuntu-latest" python: "3.9" distributed: false query-planning: false - os: "ubuntu-latest" python: "3.11" distributed: false query-planning: false steps: - uses: actions/checkout@v4 - name: Set up Python uses: conda-incubator/setup-miniconda@v2.3.0 with: miniforge-variant: Mambaforge use-mamba: true python-version: ${{ matrix.python }} channel-priority: strict activate-environment: dask-sql environment-file: ${{ env.CONDA_FILE }} run-post: ${{ matrix.os != 'windows-latest' && 'true' || 
'false' }} - uses: actions-rs/toolchain@v1 with: toolchain: 1.72 default: true - name: Install x86_64-apple-darwin target if: matrix.os == 'macos-latest' run: rustup target add x86_64-apple-darwin - name: Build the Rust DataFusion bindings run: | maturin develop - name: Install hive testing dependencies if: matrix.os == 'ubuntu-latest' run: | docker pull bde2020/hive:2.3.2-postgresql-metastore docker pull bde2020/hive-metastore-postgresql:2.3.0 - name: Optionally install upstream dev Dask if: needs.detect-ci-trigger.outputs.triggered == 'true' run: | mamba install --no-channel-priority dask/label/dev::dask - name: Test with pytest run: | pytest --junitxml=junit/test-results.xml --cov-report=xml -n auto tests --dist loadfile - name: Upload pytest test results if: always() uses: actions/upload-artifact@v3 with: name: pytest-results path: junit/test-results.xml - name: Upload coverage to Codecov if: github.repository == 'dask-contrib/dask-sql' uses: codecov/codecov-action@v3 import: name: "Test importing with bare requirements (query-planning: ${{ matrix.query-planning }})" needs: [detect-ci-trigger] runs-on: ubuntu-latest strategy: fail-fast: false matrix: query-planning: [true, false] steps: - uses: actions/checkout@v4 - name: Set up Python uses: conda-incubator/setup-miniconda@v2.3.0 with: miniforge-variant: Mambaforge use-mamba: true python-version: "3.9" channel-priority: strict - uses: actions-rs/toolchain@v1 with: toolchain: 1.72 default: true - name: Install dependencies and nothing else run: | pip install -e . 
-vv which python pip list mamba list - name: Optionally install upstream dev Dask if: needs.detect-ci-trigger.outputs.triggered == 'true' run: | python -m pip install git+https://github.com/dask/dask python -m pip install git+https://github.com/dask/dask-expr python -m pip install git+https://github.com/dask/distributed - name: Try to import dask-sql env: DASK_DATAFRAME_QUERY_PLANNING: ${{ matrix.query-planning }} run: | python -c "import dask_sql; print('ok')" ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST *.so # Unit test / coverage reports htmlcov/ .coverage .coverage.* .cache coverage.xml *.cover .pytest_cache/ .hypothesis/ .pytest-html # Jupyter Notebook .ipynb_checkpoints # environments conda-env env venv # IDE .idea .vscode *.swp # project specific dask-worker-space/ node_modules/ docs/source/_build/ tests/unit/queries tests/unit/data target/* packages/* # Ignore development specific local testing files dev_tests dev-tests ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/psf/black rev: 22.10.0 hooks: - id: black language_version: python3 - repo: https://github.com/PyCQA/flake8 rev: 5.0.4 hooks: - id: flake8 language_version: python3 - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort args: - "--profile" - "black" - repo: https://github.com/doublify/pre-commit-rust rev: v1.0 hooks: - id: cargo-check args: ['--manifest-path', './Cargo.toml', '--verbose', '--'] - id: clippy args: ['--manifest-path', './Cargo.toml', '--verbose', '--', '-D', 'warnings'] - repo: https://github.com/pre-commit/pre-commit-hooks 
rev: v4.2.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer - id: check-yaml exclude: ^continuous_integration/recipe/ - id: check-added-large-files - repo: local hooks: - id: cargo-fmt name: cargo fmt description: Format files with cargo fmt. entry: cargo +nightly fmt language: system types: [rust] args: ['--manifest-path', './Cargo.toml', '--verbose', '--'] ================================================ FILE: .readthedocs.yaml ================================================ # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details version: 2 build: os: ubuntu-20.04 tools: python: "mambaforge-4.10" sphinx: configuration: docs/source/conf.py conda: environment: docs/environment.yml python: install: - method: pip path: . ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 
## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at nilslennartbraun@gmail.com. 
All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to Dask-SQL ## Environment Setup The environment used for development and CI consists of: - a system installation of [`rustup`](https://rustup.rs/) with: - the latest stable toolchain - the latest nightly `rustfmt` - a [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) environment containing all required Python packages Once `rustup` is installed, ensure that the latest stable toolchain and nightly `rustfmt` are available by running ``` rustup toolchain install nightly -c rustfmt --profile minimal rustup update ``` To initialize and activate the conda environment for a given Python version: ``` conda env create -f dask-sql/continuous_integration/environment-{$PYTHON_VER}.yaml conda activate dask-sql ``` ## Rust Developers Guide Dask-SQL utilizes [Apache Arrow Datafusion](https://github.com/apache/arrow-datafusion) for parsing, planning, and optimizing SQL queries. DataFusion is written in Rust and therefore requires some Rust experience to be productive. 
Luckily, there are tons of great Rust learning resources on the internet. We have listed some of our favorite ones [here](#rust-learning-resources). ### Apache Arrow DataFusion The Dask-SQL Rust codebase makes heavy use of [Apache Arrow DataFusion](https://github.com/apache/arrow-datafusion). Contributors should familiarize themselves with the [codebase](https://github.com/apache/arrow-datafusion) and [documentation](https://docs.rs/datafusion/latest/datafusion/). #### Purpose DataFusion provides Dask-SQL with key functionality. - Parsing SQL query strings into a `LogicalPlan` datastructure - Future integration points with [substrait.io](https://substrait.io/) - An optimization framework used as the baseline for creating custom highly efficient `LogicalPlan`s specific to Dask. ### Building Building the Dask-SQL Rust codebase is a straightforward process. If you create and activate the Dask-SQL Conda environment, the Rust compiler and all necessary components will be installed for you during that process, so no further manual setup is required. `maturin` is used by Dask-SQL for building and bundling the resulting Rust binaries. This helps make building and installing the Rust binaries feel much more like a native Python workflow. More details about the building setup can be found in [pyproject.toml](pyproject.toml) and [Cargo.toml](Cargo.toml). Note that while `maturin` is used by CI and should be used during your development cycle, if the need arises to do something more specific that is not yet supported by `maturin` you can opt to use `cargo` directly from the command line. #### Building with Python Building Dask-SQL is straightforward with Python. To build run ```pip install .```. This will build both the Rust and Python codebase and install it into your locally activated conda environment; note that if your Rust dependencies have been updated, this command must be rerun to rebuild the Rust codebase. 
#### DataFusion Modules DataFusion is broken down into a few modules. We consume those modules in our [Cargo.toml](Cargo.toml). The modules that we currently use are: - `datafusion-common` - Datastructures and core logic - `datafusion-expr` - Expression based logic and operators - `datafusion-sql` - SQL components such as parsing and planning - `datafusion-optimizer` - Optimization logic and datastructures for modifying current plans into more efficient ones. #### Retrieving Upstream Dependencies During development you might find yourself needing some upstream DataFusion changes not present in the project's current version. Luckily this can easily be achieved by updating [Cargo.toml](Cargo.toml) and changing the `rev` to the SHA of the version you need. Note that the same SHA should be used for all DataFusion modules. #### Local Documentation Sometimes when building against the latest GitHub commits for DataFusion you may find that the features you are consuming do not have their documentation public yet. In this case it can be helpful to build the DataFusion documentation locally so that it can be referenced to assist with development. Here is a rough outline for building that documentation locally. - clone https://github.com/apache/arrow-datafusion - change into the `arrow-datafusion` directory - run `cargo doc` - navigate to `target/doc/datafusion/all.html` and open in your desired browser ### Datastructures While working in the Rust codebase there are a few datastructures that you should make yourself familiar with. This section does not aim to verbosely list out all of the datastructures within the project but rather just the key datastructures that you are likely to encounter while working on almost any feature/issue. The aim is to give you a better overview of the codebase without having to manually dig through all the source code. 
- [`PyLogicalPlan`](src/sql/logical.rs) -> [DataFusion LogicalPlan](https://docs.rs/datafusion/latest/datafusion/logical_plan/enum.LogicalPlan.html) - Often encountered in Python code with variable name `rel` - Python serializable umbrella representation of the entire LogicalPlan that was generated by DataFusion - Provides access to `DaskTable` instances and type information for each table - Access to individual nodes in the logical plan tree. Ex: `TableScan` - [`DaskSQLContext`](src/sql.rs) - Analogous to Python `Context` - Contains metadata about the tables, schemas, functions, operators, and configurations that are present within the current execution context - When adding custom functions/UDFs this is the location that you would register them - Entry point for parsing SQL strings to SQL node trees. This is the location Python will begin its interactions with Rust - [`PyExpr`](src/expression.rs) -> [DataFusion Expr](https://docs.rs/datafusion/latest/datafusion/prelude/enum.Expr.html) - Arguably where most of your time will be spent - Represents a single node in the SQL tree. Ex: `avg(age)` from `SELECT avg(age) FROM people` - Is associated with a single `RexType` - Can contain literal values or represent function calls, `avg()` for example - The expression's "index" in the tree can be retrieved by calling `PyExpr.index()` on an instance. This is useful when mapping frontend column names in Dask code to backend Dataframe columns - Certain `PyExpr`s contain operands. Ex: `2 + 2` would contain 3 operands. 1) A literal `PyExpr` instance with value 2 2) Another literal `PyExpr` instance with a value of 2. 3) A `+` `PyExpr` representing the addition of the 2 literals. 
- [`DaskSqlOptimizer`](src/sql/optimizer.rs) - Registering location for all Dask-SQL specific logical plan optimizations - Optimizations, whether custom-written or reused from another source such as DataFusion, are registered here in the order in which they should be executed - Represents functions that modify/convert an original `PyLogicalPlan` into another `PyLogicalPlan` that would be more efficient when running in the underlying Dask framework - [`RelDataType`](src/sql/types/rel_data_type.rs) - Not a fan of this name, was chosen to match existing Calcite logic - Represents a "row" in a table - Contains a list of "columns" that are present in that row - [RelDataTypeField](src/sql/types/rel_data_type_field.rs) - [RelDataTypeField](src/sql/types/rel_data_type_field.rs) - Represents an individual column in a table - Contains: - `qualifier` - schema the field belongs to - `name` - name of the column/field - `data_type` - `DaskTypeMap` instance containing information about the SQL type and underlying Arrow DataType - `index` - location of the field in the LogicalPlan - [DaskTypeMap](src/sql/types.rs) - Maps a conventional SQL type to an underlying Arrow DataType ### Rust Learning Resources - ["The Book"](https://doc.rust-lang.org/book/) - [Let's Get Rusty "LGR" YouTube series](https://www.youtube.com/c/LetsGetRusty) ## Documentation TODO - [ ] SQL Parsing overview diagram - [ ] Architecture diagram - [x] Setup dev environment - [x] Version of Rust and specs - [x] Updating version of datafusion - [x] Building - [x] Rust learning resources - [x] Rust Datastructures local to Dask-SQL - [x] Build DataFusion documentation locally - [ ] Python & Rust with PyO3 - [ ] Types mapping, Arrow datatypes - [ ] RexTypes explanation, show simple query and show it broken down into its parts in a diagram - [ ] Registering tables with DaskSqlContext, also functions - [ ] Creating your own optimizer - [ ] Simple diagram of PyExpr, showing something like 2+2 but broken down into a tree 
looking diagram ================================================ FILE: Cargo.toml ================================================ [package] name = "dask-sql" repository = "https://github.com/dask-contrib/dask-sql" version = "2024.5.0" description = "Bindings for DataFusion used by Dask-SQL" readme = "README.md" license = "Apache-2.0" edition = "2021" rust-version = "1.72" include = ["/src", "/dask_sql", "/LICENSE.txt", "pyproject.toml", "Cargo.toml", "Cargo.lock"] [dependencies] async-trait = "0.1.78" datafusion-python = { git = "https://github.com/apache/arrow-datafusion-python.git", ref = "da6c183" } env_logger = "0.11" log = "^0.4" pyo3 = { version = "0.19.2", features = ["extension-module", "abi3", "abi3-py39"] } pyo3-log = "0.9.0" [build-dependencies] pyo3-build-config = "0.20.3" [lib] name = "dask_sql" crate-type = ["cdylib", "rlib"] [profile.release] lto = true codegen-units = 1 ================================================ FILE: LICENSE.txt ================================================ MIT LICENCE Copyright (c) 2020 Nils Braun Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: MANIFEST.in ================================================ recursive-include dask_sql *.yaml recursive-include dask_planner * ================================================ FILE: README.md ================================================ **Dask-SQL is currently not in active maintenance, see [#1344](https://github.com/dask-contrib/dask-sql/issues/1344) for more information** [![Conda](https://img.shields.io/conda/v/conda-forge/dask-sql)](https://anaconda.org/conda-forge/dask-sql) [![PyPI](https://img.shields.io/pypi/v/dask-sql?logo=pypi)](https://pypi.python.org/pypi/dask-sql/) [![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/dask-contrib/dask-sql/test.yml?branch=main)](https://github.com/dask-contrib/dask-sql/actions/workflows/test.yml?query=branch%3Amain) [![Read the Docs](https://img.shields.io/readthedocs/dask-sql)](https://dask-sql.readthedocs.io/en/latest/) [![Codecov](https://img.shields.io/codecov/c/github/dask-contrib/dask-sql?logo=codecov)](https://codecov.io/gh/dask-contrib/dask-sql) [![GitHub](https://img.shields.io/github/license/dask-contrib/dask-sql)](https://github.com/dask-contrib/dask-sql/blob/main/LICENSE.txt) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/dask-contrib/dask-sql-binder/main?urlpath=lab)
SQL + Python
`dask-sql` is a distributed SQL query engine in Python. It allows you to query and transform your data using a mixture of common SQL operations and Python code and also scale up the calculation easily if you need it. * **Combine the power of Python and SQL**: load your data with Python, transform it with SQL, enhance it with Python and query it with SQL - or the other way round. With `dask-sql` you can mix the well known Python dataframe API of `pandas` and `Dask` with common SQL operations, to process your data in exactly the way that is easiest for you. * **Infinite Scaling**: using the power of the great `Dask` ecosystem, your computations can scale as you need it - from your laptop to your super cluster - without changing any line of SQL code. From k8s to cloud deployments, from batch systems to YARN - if `Dask` [supports it](https://docs.dask.org/en/latest/setup.html), so will `dask-sql`. * **Your data - your queries**: Use Python user-defined functions (UDFs) in SQL without any performance drawback and extend your SQL queries with the large number of Python libraries, e.g. machine learning, different complicated input formats, complex statistics. * **Easy to install and maintain**: `dask-sql` is just a pip/conda install away (or a docker run if you prefer). * **Use SQL from wherever you like**: `dask-sql` integrates with your jupyter notebook, your normal Python module or can be used as a standalone SQL server from any BI tool. It even integrates natively with [Apache Hue](https://gethue.com/). * **GPU Support**: `dask-sql` supports running SQL queries on CUDA-enabled GPUs by utilizing [RAPIDS](https://rapids.ai) libraries like [`cuDF`](https://github.com/rapidsai/cudf), enabling accelerated compute for SQL. Read more in the [documentation](https://dask-sql.readthedocs.io/en/latest/).
dask-sql GIF
--- ## Example For this example, we use some data loaded from disk and query them with a SQL command from our python code. Any pandas or dask dataframe can be used as input and ``dask-sql`` understands a large amount of formats (csv, parquet, json,...) and locations (s3, hdfs, gcs,...). ```python import dask.dataframe as dd from dask_sql import Context # Create a context to hold the registered tables c = Context() # Load the data and register it in the context # This will give the table a name, that we can use in queries df = dd.read_csv("...") c.create_table("my_data", df) # Now execute a SQL query. The result is again dask dataframe. result = c.sql(""" SELECT my_data.name, SUM(my_data.x) FROM my_data GROUP BY my_data.name """, return_futures=False) # Show the result print(result) ``` ## Quickstart Have a look into the [documentation](https://dask-sql.readthedocs.io/en/latest/) or start the example notebook on [binder](https://mybinder.org/v2/gh/dask-contrib/dask-sql-binder/main?urlpath=lab). > `dask-sql` is currently under development and does so far not understand all SQL commands (but a large fraction). We are actively looking for feedback, improvements and contributors! ## Installation `dask-sql` can be installed via `conda` (preferred) or `pip` - or in a development environment. ### With `conda` Create a new conda environment or use your already present environment: conda create -n dask-sql conda activate dask-sql Install the package from the `conda-forge` channel: conda install dask-sql -c conda-forge ### With `pip` You can install the package with pip install dask-sql ### For development If you want to have the newest (unreleased) `dask-sql` version or if you plan to do development on `dask-sql`, you can also install the package from sources. 
git clone https://github.com/dask-contrib/dask-sql.git Create a new conda environment and install the development environment: conda env create -f continuous_integration/environment-3.9.yaml It is not recommended to use `pip` instead of `conda` for the environment setup. After that, you can install the package in development mode pip install -e ".[dev]" The Rust DataFusion bindings are built as part of the `pip install`. Note that if changes are made to the Rust source in `src/`, another build must be run to recompile the bindings. This repository uses [pre-commit](https://pre-commit.com/) hooks. To install them, call pre-commit install ## Testing You can run the tests (after installation) with pytest tests GPU-specific tests require additional dependencies specified in the per-Python-version environment files under `continuous_integration/gpuci/` (for example `continuous_integration/gpuci/environment-3.9.yaml`). These can be added to the development environment by running ``` conda env update -n dask-sql -f continuous_integration/gpuci/environment-3.9.yaml ``` And GPU-specific tests can be run with ``` pytest tests -m gpu --rungpu ``` ## SQL Server `dask-sql` comes with a small test implementation for a SQL server. Instead of rebuilding a full ODBC driver, we re-use the [presto wire protocol](https://github.com/prestodb/presto/wiki/HTTP-Protocol). It is - so far - only a start of the development and missing important concepts, such as authentication. You can test the sql presto server by running (after installation) dask-sql-server or by using the created docker image docker run --rm -it -p 8080:8080 nbraun/dask-sql in one terminal. This will spin up a server on port 8080 (by default) that looks similar to a normal presto database to any presto client. 
You can test this for example with the default [presto client](https://prestosql.io/docs/current/installation/cli.html): presto --server localhost:8080 Now you can fire simple SQL queries (as no data is loaded by default): => SELECT 1 + 1; EXPR$0 -------- 2 (1 row) You can find more information in the [documentation](https://dask-sql.readthedocs.io/en/latest/pages/server.html). ## CLI You can also run the CLI `dask-sql` for testing out SQL commands quickly: dask-sql --load-test-data --startup (dask-sql) > SELECT * FROM timeseries LIMIT 10; ## How does it work? At the core, `dask-sql` does two things: - translate the SQL query using [DataFusion](https://arrow.apache.org/datafusion) into a relational algebra, which is represented as a logical query plan - similar to many other SQL engines (Hive, Flink, ...) - convert this description of the query into dask API calls (and execute them) - returning a dask dataframe. For the first step, Arrow DataFusion needs to know about the columns and types of the dask dataframes; therefore, some Rust code to store this information for dask dataframes is defined in `dask_planner`. After the translation to a relational algebra is done (using `DaskSQLContext.logical_relational_algebra`), the python methods defined in `dask_sql.physical` turn this into a physical dask execution plan by converting each piece of the relational algebra one-by-one. 
================================================ FILE: conftest.py ================================================
"""Repository-level pytest configuration: custom CLI options, per-test setup, and path fixtures."""
import dask
import pytest

# Plugin module(s) that pytest loads in addition to this conftest.
pytest_plugins = ["tests.integration.fixtures"]


def pytest_addoption(parser):
    """Register the custom command line options used by the test suite.

    ``--rungpu`` and ``--runqueries`` are boolean flags (``store_true``);
    ``--data_dir`` and ``--queries_dir`` take a value.
    """
    parser.addoption("--rungpu", action="store_true", help="run tests meant for GPU")
    parser.addoption("--runqueries", action="store_true", help="run test queries")
    parser.addoption("--data_dir", help="specify file path to the data")
    parser.addoption("--queries_dir", help="specify file path to the queries")


def pytest_runtest_setup(item):
    """Apply dask configuration before a test and skip it when its opt-in flag is missing.

    Items carrying the ``gpu`` keyword are skipped unless ``--rungpu`` was
    given; items carrying the ``queries`` keyword are skipped unless
    ``--runqueries`` was given.
    """
    # TODO: get pyarrow strings and p2p shuffle working
    dask.config.set({"dataframe.convert-string": False})
    dask.config.set({"dataframe.shuffle.method": "tasks"})
    if "gpu" in item.keywords:
        if not item.config.getoption("--rungpu"):
            pytest.skip("need --rungpu option to run")
        # manually enable cudf decimal support
        # (only reached when the GPU test is not skipped above)
        dask.config.set({"sql.mappings.decimal_support": "cudf"})
    if "queries" in item.keywords and not item.config.getoption("--runqueries"):
        pytest.skip("need --runqueries option to run")


@pytest.fixture(scope="session")
def data_dir(request):
    """Return the value passed via ``--data_dir`` (session-scoped)."""
    return request.config.getoption("--data_dir")


@pytest.fixture(scope="session")
def queries_dir(request):
    """Return the value passed via ``--queries_dir`` (session-scoped)."""
    return request.config.getoption("--queries_dir")

================================================ FILE: continuous_integration/docker/cloud.dockerfile ================================================ ARG DOCKER_META_VERSION FROM nbraun/dask-sql:${DOCKER_META_VERSION} RUN conda config --add channels conda-forge \ && /opt/conda/bin/mamba install --freeze-installed -y \ s3fs \ dask-cloudprovider \ && pip install awscli \ && conda clean -ay ENTRYPOINT ["tini", "-g", "--", "/usr/bin/prepare.sh"] ================================================ FILE: continuous_integration/docker/conda.txt ================================================ python>=3.9 dask>=2024.4.1 pandas>=1.4.0 jpype1>=1.0.2 openjdk>=8 maven>=3.6.0 pytest>=6.0.2 pytest-cov>=2.10.1 pytest-xdist mock>=4.0.3 
sphinx>=3.2.1 tzlocal>=2.1 fastapi>=0.92.0 httpx>=0.24.1 uvicorn>=0.14 pyarrow>=14.0.1 prompt_toolkit>=3.0.8 pygments>=2.7.1 scikit-learn>=1.0.0 intake>=0.6.0 pre-commit>=2.11.1 black=22.10.0 isort=5.12.0 maturin>=1.3,<1.4 ================================================ FILE: continuous_integration/docker/main.dockerfile ================================================ # Dockerfile for dask-sql running the SQL server # For more information, see https://dask-sql.readthedocs.io/. FROM daskdev/dask:latest LABEL author "Nils Braun " # Install rustc & gcc for compilation of DataFusion planner ADD https://sh.rustup.rs /rustup-init.sh RUN sh /rustup-init.sh -y --default-toolchain=stable --profile=minimal \ && apt-get update \ && apt-get install gcc -y ENV PATH="/root/.cargo/bin:${PATH}" # Install conda dependencies for dask-sql COPY continuous_integration/docker/conda.txt /opt/dask_sql/ RUN mamba install -y \ # build requirements "maturin>=1.3,<1.4" \ # core dependencies "dask>=2024.4.1" \ "pandas>=1.4.0" \ "fastapi>=0.92.0" \ "httpx>=0.24.1" \ "uvicorn>=0.14" \ "tzlocal>=2.1" \ "prompt_toolkit>=3.0.8" \ "pygments>=2.7.1" \ tabulate \ # additional dependencies "pyarrow>=14.0.1" \ "scikit-learn>=1.0.0" \ "intake>=0.6.0" \ && conda clean -ay # install dask-sql COPY Cargo.toml /opt/dask_sql/ COPY Cargo.lock /opt/dask_sql/ COPY pyproject.toml /opt/dask_sql/ COPY setup.cfg /opt/dask_sql/ COPY README.md /opt/dask_sql/ COPY .git /opt/dask_sql/.git COPY src /opt/dask_sql/src COPY dask_sql /opt/dask_sql/dask_sql RUN cd /opt/dask_sql/ \ && CONDA_PREFIX="/opt/conda/" maturin develop # Set the script to execute COPY continuous_integration/scripts/startup_script.py /opt/dask_sql/startup_script.py EXPOSE 8080 ENTRYPOINT [ "/usr/bin/prepare.sh", "/opt/conda/bin/python", "/opt/dask_sql/startup_script.py" ] ================================================ FILE: continuous_integration/environment-3.10.yaml ================================================ name: dask-sql channels: - 
conda-forge dependencies: - c-compiler - dask>=2024.4.1 - dask-expr>=1.0.11 - docker-py>=7.1.0 - fastapi>=0.92.0 - fugue>=0.7.3 - httpx>=0.24.1 - intake>=0.6.0 - jsonschema - lightgbm - maturin>=1.3,<1.4 - mlflow>=2.10 - mock - numpy>=1.22.4 - pandas>=2 - pre-commit - prompt_toolkit>=3.0.8 - psycopg2 - pyarrow>=14.0.1 - pygments>=2.7.1 - pyhive - pytest-cov - pytest-rerunfailures - pytest-xdist - pytest - python=3.10 - py-xgboost>=2.0.3 - scikit-learn>=1.0.0 - sphinx - sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 - tzlocal>=2.1 - uvicorn>=0.14 - zlib ================================================ FILE: continuous_integration/environment-3.11.yaml ================================================ name: dask-sql channels: - conda-forge dependencies: - c-compiler - dask>=2024.4.1 - dask-expr>=1.0.11 - docker-py>=7.1.0 - fastapi>=0.92.0 - fugue>=0.7.3 - httpx>=0.24.1 - intake>=0.6.0 - jsonschema - lightgbm - maturin>=1.3,<1.4 - mlflow>=2.10 - mock - numpy>=1.22.4 - pandas>=2 - pre-commit - prompt_toolkit>=3.0.8 - psycopg2 - pyarrow>=14.0.1 - pygments>=2.7.1 - pyhive - pytest-cov - pytest-rerunfailures - pytest-xdist - pytest - python=3.11 - py-xgboost>=2.0.3 - scikit-learn>=1.0.0 - sphinx - sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 - tzlocal>=2.1 - uvicorn>=0.14 - zlib ================================================ FILE: continuous_integration/environment-3.12.yaml ================================================ name: dask-sql channels: - conda-forge dependencies: - c-compiler - dask>=2024.4.1 - dask-expr>=1.0.11 - docker-py>=7.1.0 - fastapi>=0.92.0 - fugue>=0.7.3 - httpx>=0.24.1 - intake>=0.6.0 - jsonschema - lightgbm - maturin>=1.3,<1.4 # TODO: add once mlflow 3.12 builds are available # - mlflow>=2.10 - mock - numpy>=1.22.4 - pandas>=2 - pre-commit - prompt_toolkit>=3.0.8 - psycopg2 - pyarrow>=14.0.1 - pygments>=2.7.1 - pyhive - pytest-cov - 
pytest-rerunfailures - pytest-xdist - pytest - python=3.12 - py-xgboost>=2.0.3 - scikit-learn>=1.0.0 - sphinx - sqlalchemy # TODO: add once tpot supports python 3.12 # - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 - tzlocal>=2.1 - uvicorn>=0.14 - zlib ================================================ FILE: continuous_integration/environment-3.9.yaml ================================================ name: dask-sql-py39 channels: - conda-forge dependencies: - c-compiler - dask=2024.4.1 - dask-expr=1.0.11 - docker-py>=7.1.0 - fastapi=0.92.0 - fugue=0.7.3 - httpx=0.24.1 - intake=0.6.0 - jsonschema - lightgbm - maturin=1.3 - mlflow=2.10 - mock - numpy=1.22.4 - pandas=2 - pre-commit - prompt_toolkit=3.0.8 - psycopg2 - pyarrow=14.0.1 - pygments=2.7.1 - pyhive - pytest-cov - pytest-rerunfailures - pytest-xdist - pytest - python=3.9 - py-xgboost=2.0.3 - scikit-learn=1.0.0 - sphinx - sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 - tzlocal=2.1 - uvicorn=0.14 - zlib ================================================ FILE: continuous_integration/gpuci/environment-3.10.yaml ================================================ name: dask-sql channels: - rapidsai - rapidsai-nightly - dask/label/dev - conda-forge - nvidia - nodefaults dependencies: - c-compiler - zlib - dask>=2024.4.1 - dask-expr>=1.0.11 - fastapi>=0.92.0 - fugue>=0.7.3 - httpx>=0.24.1 - intake>=0.6.0 - jsonschema - lightgbm - maturin>=1.3,<1.4 - mock - numpy>=1.22.4 - pandas>=2 - pre-commit - prompt_toolkit>=3.0.8 - psycopg2 - pyarrow>=14.0.1 - pygments>=2.7.1 - pyhive - pytest-cov - pytest-rerunfailures - pytest-xdist - pytest - python=3.10 - py-xgboost>=2.0.3 - scikit-learn>=1.0.0 - sphinx - sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 - tzlocal>=2.1 - uvicorn>=0.14 # GPU-specific requirements - cudatoolkit=11.8 - cudf=24.06 - cuml=24.06 - dask-cudf=24.06 - dask-cuda=24.06 
- ucx-proc=*=gpu - ucx-py=0.38 - xgboost=*=rapidsai_py* - libxgboost=*=rapidsai_h* ================================================ FILE: continuous_integration/gpuci/environment-3.11.yaml ================================================ name: dask-sql channels: - rapidsai - rapidsai-nightly - dask/label/dev - conda-forge - nvidia - nodefaults dependencies: - c-compiler - zlib - dask>=2024.4.1 - dask-expr>=1.0.11 - fastapi>=0.92.0 - fugue>=0.7.3 - httpx>=0.24.1 - intake>=0.6.0 - jsonschema - lightgbm - maturin>=1.3,<1.4 - mock - numpy>=1.22.4 - pandas>=2 - pre-commit - prompt_toolkit>=3.0.8 - psycopg2 - pyarrow>=14.0.1 - pygments>=2.7.1 - pyhive - pytest-cov - pytest-rerunfailures - pytest-xdist - pytest - python=3.11 - py-xgboost>=2.0.3 - scikit-learn>=1.0.0 - sphinx - sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 - tzlocal>=2.1 - uvicorn>=0.14 # GPU-specific requirements - cudatoolkit=11.8 - cudf=24.06 - cuml=24.06 - dask-cudf=24.06 - dask-cuda=24.06 - ucx-proc=*=gpu - ucx-py=0.38 - xgboost=*=rapidsai_py* - libxgboost=*=rapidsai_h* ================================================ FILE: continuous_integration/gpuci/environment-3.9.yaml ================================================ name: dask-sql channels: - rapidsai - rapidsai-nightly - dask/label/dev - conda-forge - nvidia - nodefaults dependencies: - c-compiler - zlib - dask>=2024.4.1 - dask-expr>=1.0.11 - fastapi>=0.92.0 - fugue>=0.7.3 - httpx>=0.24.1 - intake>=0.6.0 - jsonschema - lightgbm - maturin>=1.3,<1.4 - mock - numpy>=1.22.4 - pandas>=2 - pre-commit - prompt_toolkit>=3.0.8 - psycopg2 - pyarrow>=14.0.1 - pygments>=2.7.1 - pyhive - pytest-cov - pytest-rerunfailures - pytest-xdist - pytest - python=3.9 - py-xgboost==2.0.3 - scikit-learn>=1.0.0 - sphinx - sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 - tzlocal>=2.1 - uvicorn>=0.14 # GPU-specific requirements - cudatoolkit=11.8 - cudf=24.06 - 
#!/bin/bash

# conda-build script: configures the Rust/cargo linker for the target platform,
# builds the dask-sql wheel with maturin and installs it into the host env.

set -ex

# See https://github.com/conda-forge/rust-feedstock/blob/master/recipe/build.sh for cc env explanation
if [ "$c_compiler" = gcc ] ; then
    case "$target_platform" in
        linux-64) rust_env_arch=X86_64_UNKNOWN_LINUX_GNU ;;
        linux-aarch64) rust_env_arch=AARCH64_UNKNOWN_LINUX_GNU ;;
        linux-ppc64le) rust_env_arch=POWERPC64LE_UNKNOWN_LINUX_GNU ;;
        *) echo "unknown target_platform $target_platform" ; exit 1 ;;
    esac

    export CARGO_TARGET_${rust_env_arch}_LINKER=$CC
fi

# Extra arguments appended to the maturin invocation (only set when targeting macOS)
declare -a _xtra_maturin_args

mkdir -p "$SRC_DIR/.cargo"

if [ "$target_platform" = "osx-64" ] ; then
    # Append the linker settings to the cargo config through a here-document;
    # the terminating EOF must stay in column 0.
    cat <<EOF >> "$SRC_DIR/.cargo/config"
[target.x86_64-apple-darwin]
linker = "$CC"
rustflags = [
  "-C", "link-arg=-undefined",
  "-C", "link-arg=dynamic_lookup",
]

EOF

    _xtra_maturin_args+=(--target=x86_64-apple-darwin)

elif [ "$target_platform" = "osx-arm64" ] ; then
    cat <<EOF >> "$SRC_DIR/.cargo/config"
# Required for intermediate codegen stuff
[target.x86_64-apple-darwin]
linker = "$CC_FOR_BUILD"

# Required for final binary artifacts for target
[target.aarch64-apple-darwin]
linker = "$CC"
rustflags = [
  "-C", "link-arg=-undefined",
  "-C", "link-arg=dynamic_lookup",
]

EOF

    _xtra_maturin_args+=(--target=aarch64-apple-darwin)

    # This variable must be set to the directory containing the target's libpython DSO
    export PYO3_CROSS_LIB_DIR=$PREFIX/lib

    # xref: https://github.com/PyO3/pyo3/commit/7beb2720
    export PYO3_PYTHON_VERSION=${PY_VER}

    # xref: https://github.com/conda-forge/python-feedstock/issues/621
    sed -i.bak 's,aarch64,arm64,g' "$BUILD_PREFIX/venv/lib/os-patch.py"
    sed -i.bak 's,aarch64,arm64,g' "$BUILD_PREFIX/venv/lib/platform-patch.py"
fi

maturin build -vv -j "${CPU_COUNT}" --release --strip --manylinux off --interpreter="${PYTHON}" "${_xtra_maturin_args[@]}"

# The glob has to stay outside the quotes so the shell expands it to the built wheel
"${PYTHON}" -m pip install "$SRC_DIR"/target/wheels/dask_sql*.whl --no-deps -vv
# FIXME: can we modify TLS model of Rust object to avoid aarch64 glibc bug?
# https://github.com/dask-contrib/dask-sql/issues/1169
from . import _datafusion_lib  # isort:skip

import importlib.metadata

from dask.config import set

from . import config
from .cmd import cmd_loop
from .context import Context
from .datacontainer import Statistics
from .server.app import run_server

# TODO: get pyarrow strings and p2p shuffle working
set(dataframe__convert_string=False, dataframe__shuffle__method="tasks")

__version__ = importlib.metadata.version(__name__)

# ``__all__`` must contain the *names* of the public objects as strings.
# Listing the objects themselves makes ``from dask_sql import *`` fail with
# ``TypeError: Item in dask_sql.__all__ must be str``.
__all__ = ["__version__", "cmd_loop", "Context", "run_server", "Statistics"]
def _meta_commands(sql: str, context: Context, client: Client) -> Union[bool, Client]:
    """
    Parse a meta command entered in the CLI and print its result.

    Args:
        sql (:obj:`str`): The raw line typed by the user, e.g. ``"\\dt my_schema"``.
        context (:obj:`dask_sql.Context`): The context whose schemas are inspected
            (and whose current schema is switched by ``\\dss``).
        client (:obj:`dask.distributed.Client`): The client used for ``\\conninfo``
            and closed on ``quit``.

    Returns:
        A new :obj:`dask.distributed.Client` for ``\\dsc``, ``True`` if any other
        meta command (including an unknown one starting with a backslash) was
        handled, and ``False`` if the input is not a meta command at all.
    """
    cmd, argument = _parse_meta_command(sql)
    available_commands = [
        ["\\l", "List schemas"],
        ["\\d?, help, ?", "Show available commands"],
        ["\\conninfo", "Show Dask cluster info"],
        ["\\dt [schema]", "List tables"],
        ["\\df [schema]", "List functions"],
        ["\\dm [schema]", "List models"],
        ["\\de [schema]", "List experiments"],
        ["\\dss [schema]", "Switch schema"],
        ["\\dsc [dask scheduler address]", "Switch Dask cluster"],
        ["quit", "Quits dask-sql-cli"],
    ]
    if cmd == "\\dsc":
        # Switch Dask cluster: the argument of this command is the scheduler address
        client = Client(argument)
        return client  # pragma: no cover

    # Every other command takes an optional schema and falls back to the current one
    schema_name = argument or context.schema_name

    # Listing commands mapped to (attribute of the schema container, column header)
    listings = {
        "\\dt": ("tables", "Tables"),
        "\\df": ("functions", "Functions"),
        "\\de": ("experiments", "Experiments"),
        "\\dm": ("models", "Models"),
    }

    if cmd == "\\d?" or cmd == "help" or cmd == "?":
        _display_markdown(available_commands, columns=["Commands", "Description"])
    elif cmd == "\\l":
        _display_markdown(context.schema.keys(), columns=["Schemas"])
    elif cmd in listings:
        if schema_name in context.schema:
            attribute, header = listings[cmd]
            entries = getattr(context.schema[schema_name], attribute)
            _display_markdown(entries.keys(), columns=[header])
        else:
            # An unknown schema used to raise a KeyError that terminated the REPL;
            # report it the same way ``\dss`` does instead.
            print(f"Schema {schema_name} not available")
    elif cmd == "\\conninfo":
        cluster_info = [
            ["Dask scheduler", client.scheduler.__dict__["addr"]],
            ["Dask dashboard", client.dashboard_link],
            ["Cluster status", client.status],
            ["Dask workers", len(client.cluster.workers)],
        ]

        _display_markdown(
            cluster_info, columns=["components", "value"]
        )  # pragma: no cover
    elif cmd == "\\dss":
        if schema_name in context.schema:
            context.schema_name = schema_name
        else:
            print(f"Schema {schema_name} not available")
    elif cmd == "quit":
        print("Quitting dask-sql ...")
        client.close()  # for safer side
        sys.exit()
    elif cmd.startswith("\\"):
        print(
            f"The meta command {cmd} not available, please use commands from below list"
        )
        _display_markdown(available_commands, columns=["Commands", "Description"])
    else:
        # nothing detected probably not a meta command
        return False
    return True
def main():  # pragma: no cover
    """
    Entry point of the ``dask-sql`` command line script.

    Parses the command line options, optionally connects to an existing
    Dask scheduler and preloads a ``timeseries`` test table, and then
    hands over to :func:`cmd_loop`.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        "--scheduler-address",
        default=None,
        help="Connect to this dask scheduler if given",
    )
    arg_parser.add_argument(
        "--log-level",
        default=None,
        help="Set the log level of the server. Defaults to info.",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
    )
    # The two remaining options are plain boolean switches
    switches = (
        ("--load-test-data", "Preload some test data."),
        ("--startup", "Wait until Apache Calcite was properly loaded"),
    )
    for flag, help_text in switches:
        arg_parser.add_argument(
            flag, default=False, action="store_true", help=help_text
        )

    options = arg_parser.parse_args()

    # Only connect to a scheduler when an address was given; otherwise
    # cmd_loop creates its own client
    dask_client = (
        Client(options.scheduler_address) if options.scheduler_address else None
    )

    sql_context = Context()
    if options.load_test_data:
        test_frame = timeseries(freq="1d").reset_index(drop=False)
        sql_context.create_table("timeseries", test_frame.persist())

    cmd_loop(
        context=sql_context,
        client=dask_client,
        startup=options.startup,
        log_level=options.log_level,
    )
dask_sql.integrations.ipython import ipython_integration from dask_sql.mappings import python_to_sql_type from dask_sql.physical.rel import RelConverter, custom, logical from dask_sql.physical.rex import RexConverter, core from dask_sql.utils import ParsingException logger = logging.getLogger(__name__) class Context: """ Main object to communicate with ``dask_sql``. It holds a store of all registered data frames (= tables) and can convert SQL queries to dask data frames. The tables in these queries are referenced by the name, which is given when registering a dask dataframe. Example: .. code-block:: python from dask_sql import Context c = Context() # Register a table c.create_table("my_table", df) # Now execute an SQL query. The result is a dask dataframe result = c.sql("SELECT a, b FROM my_table") # Trigger the computation (or use the data frame for something else) result.compute() Usually, you will only ever have a single context in your program. See also: :func:`sql` :func:`create_table` """ DEFAULT_CATALOG_NAME = "dask_sql" DEFAULT_SCHEMA_NAME = "root" def __init__(self, logging_level=logging.INFO): """ Create a new context. 
""" # Set the logging level for this SQL context logging.basicConfig(level=logging_level) # Name of the root catalog self.catalog_name = self.DEFAULT_CATALOG_NAME # Name of the root schema self.schema_name = self.DEFAULT_SCHEMA_NAME # All schema information self.schema = {self.schema_name: SchemaContainer(self.schema_name)} # A started SQL server (useful for jupyter notebooks) self.sql_server = None # Create the `DaskSQLOptimizerConfig` Rust context optimizer_config = DaskSQLOptimizerConfig( dask_config.get("sql.dynamic_partition_pruning"), dask_config.get("sql.fact_dimension_ratio"), dask_config.get("sql.max_fact_tables"), dask_config.get("sql.preserve_user_order"), dask_config.get("sql.filter_selectivity"), ) # Create the `DaskSQLContext` Rust context self.context = DaskSQLContext( self.catalog_name, self.schema_name, optimizer_config ) self.context.register_schema(self.schema_name, DaskSchema(self.schema_name)) # # Register any default plugins, if nothing was registered before. RelConverter.add_plugin_class(logical.DaskAggregatePlugin, replace=False) RelConverter.add_plugin_class(logical.DaskCrossJoinPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskEmptyRelationPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskFilterPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskJoinPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskLimitPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskProjectPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskSortPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskTableScanPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskUnionPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskValuesPlugin, replace=False) RelConverter.add_plugin_class(logical.DaskWindowPlugin, replace=False) RelConverter.add_plugin_class(logical.SamplePlugin, replace=False) RelConverter.add_plugin_class(logical.ExplainPlugin, replace=False) 
RelConverter.add_plugin_class(logical.SubqueryAlias, replace=False) RelConverter.add_plugin_class(custom.AnalyzeTablePlugin, replace=False) RelConverter.add_plugin_class(custom.CreateExperimentPlugin, replace=False) RelConverter.add_plugin_class(custom.CreateModelPlugin, replace=False) RelConverter.add_plugin_class(custom.CreateCatalogSchemaPlugin, replace=False) RelConverter.add_plugin_class(custom.CreateMemoryTablePlugin, replace=False) RelConverter.add_plugin_class(custom.CreateTablePlugin, replace=False) RelConverter.add_plugin_class(custom.DropModelPlugin, replace=False) RelConverter.add_plugin_class(custom.DropSchemaPlugin, replace=False) RelConverter.add_plugin_class(custom.DropTablePlugin, replace=False) RelConverter.add_plugin_class(custom.ExportModelPlugin, replace=False) RelConverter.add_plugin_class(custom.PredictModelPlugin, replace=False) RelConverter.add_plugin_class(custom.ShowColumnsPlugin, replace=False) RelConverter.add_plugin_class(custom.DescribeModelPlugin, replace=False) RelConverter.add_plugin_class(custom.ShowModelsPlugin, replace=False) RelConverter.add_plugin_class(custom.ShowSchemasPlugin, replace=False) RelConverter.add_plugin_class(custom.ShowTablesPlugin, replace=False) RelConverter.add_plugin_class(custom.UseSchemaPlugin, replace=False) RelConverter.add_plugin_class(custom.AlterSchemaPlugin, replace=False) RelConverter.add_plugin_class(custom.AlterTablePlugin, replace=False) RelConverter.add_plugin_class(custom.DistributeByPlugin, replace=False) RexConverter.add_plugin_class(core.RexAliasPlugin, replace=False) RexConverter.add_plugin_class(core.RexCallPlugin, replace=False) RexConverter.add_plugin_class(core.RexInputRefPlugin, replace=False) RexConverter.add_plugin_class(core.RexLiteralPlugin, replace=False) RexConverter.add_plugin_class(core.RexScalarSubqueryPlugin, replace=False) InputUtil.add_plugin_class(input_utils.DaskInputPlugin, replace=False) InputUtil.add_plugin_class(input_utils.PandasLikeInputPlugin, replace=False) 
def create_table(
    self,
    table_name: str,
    input_table: InputType,
    format: str = None,
    persist: bool = False,
    schema_name: str = None,
    statistics: Statistics = None,
    gpu: bool = False,
    **kwargs,
):
    """
    Registering a (dask/pandas) table makes it usable in SQL queries.
    The name you give here can be used as table name in the SQL later.

    Please note, that the table is stored as it is now.
    If you change the table later, you need to re-register.

    Instead of passing an already loaded table, it is also possible
    to pass a string to a storage location.
    The library will then try to load the data using one of
    `dask's read methods <https://docs.dask.org/en/stable/dataframe-create.html>`_.
    If the file format can not be deduced automatically, it is also
    possible to specify it via the ``format`` parameter.
    Typical file formats are csv or parquet.
    Any additional parameters will get passed on to the read method.
    Please note that some file formats require additional libraries.
    By default, the data will be lazily loaded. If you would like to
    load the data directly into memory you can do so by setting
    persist=True.

    See :ref:`data_input` for more information.

    Example:
        This code registers a data frame as table "data"
        and then uses it in a query.

        .. code-block:: python

            c.create_table("data", df)
            df_result = c.sql("SELECT a, b FROM data")

        This code reads a file from disk.
        Please note that we assume that the file(s) are reachable under this path
        from every node in the cluster

        .. code-block:: python

            c.create_table("data", "/home/user/data.csv")
            df_result = c.sql("SELECT a, b FROM data")

        This example reads from a hive table.

        .. code-block:: python

            from pyhive.hive import connect

            cursor = connect("localhost", 10000).cursor()
            c.create_table("data", cursor, hive_table_name="the_name_in_hive")
            df_result = c.sql("SELECT a, b FROM data")

    Args:
        table_name: (:obj:`str`): Under which name should the new table be addressable.
            The table is stored under the lower-cased name.
        input_table (:class:`dask.dataframe.DataFrame` or :class:`pandas.DataFrame` or :obj:`str` or :class:`hive.Cursor`):
            The data frame/location/hive connection to register.
        format (:obj:`str`): Only used when passing a string into the ``input_table`` parameter.
            Specify the file format directly here if it can not be deduced from the extension.
            If set to "memory", load the data from a published dataset in the dask cluster.
        persist (:obj:`bool`): Only used when passing a string into the ``input_table`` parameter.
            Set to true to turn on loading the file data directly into memory.
        schema_name: (:obj:`str`): in which schema to create the table. By default, will use the currently selected schema.
        statistics: (:obj:`Statistics`): if given, use these statistics during the cost-based optimization.
        gpu: (:obj:`bool`): if set to true, use dask-cudf to run the data frame calculations on your GPU.
            Please note that the GPU support is currently not covering all of dask-sql's SQL language.
        **kwargs: Additional arguments for specific formats. See :ref:`data_input` for more information.

    """
    # Resolve the schema before logging, so the message names the schema that
    # is actually used instead of ``None``
    schema_name = schema_name or self.schema_name
    logger.debug(
        f"Creating table: '{table_name}' of format type '{format}' in schema '{schema_name}'"
    )

    # All per-table registries of the schema are keyed by the lower-cased name
    table_key = table_name.lower()

    dc = InputUtil.to_dc(
        input_table,
        table_name=table_name,
        format=format,
        persist=persist,
        gpu=gpu,
        **kwargs,
    )

    if isinstance(input_table, str):
        # A string input is a storage location; remember it as the table's file path
        dc.filepath = input_table
        self.schema[schema_name].filepaths[table_key] = input_table
    elif hasattr(input_table, "dask") and dd.utils.is_dataframe_like(input_table):
        # For dask dataframes, try to recover the path from the parquet read operation
        try:
            if dd._dask_expr_enabled():
                from dask_expr.io.parquet import ReadParquet

                dask_filepath = None
                operations = input_table.find_operations(ReadParquet)
                # If several ReadParquet operations are found, the last one wins
                for op in operations:
                    dask_filepath = op._args[0]
            else:
                dask_filepath = hlg_layer(
                    input_table.dask, "read-parquet"
                ).creation_info["args"][0]

            dc.filepath = dask_filepath
            self.schema[schema_name].filepaths[table_key] = dask_filepath
        except KeyError:
            logger.debug("Expected 'read-parquet' layer")

    # Derive the row count from parquet metadata when no statistics were supplied
    if parquet_statistics and not dd._dask_expr_enabled() and not statistics:
        statistics = parquet_statistics(dc.df)
        if statistics:
            row_count = sum(d["num-rows"] for d in statistics)
            statistics = Statistics(row_count)
    if not statistics:
        # Unknown row count
        statistics = Statistics(float("nan"))
    dc.statistics = statistics

    self.schema[schema_name].tables[table_key] = dc
    self.schema[schema_name].statistics[table_key] = statistics
""" if schema_name == self.DEFAULT_SCHEMA_NAME: raise RuntimeError(f"Default Schema `{schema_name}` cannot be deleted") del self.schema[schema_name] if self.schema_name == schema_name: self.schema_name = self.DEFAULT_SCHEMA_NAME def register_function( self, f: Callable, name: str, parameters: list[tuple[str, type]], return_type: type, replace: bool = False, schema_name: str = None, row_udf: bool = False, ): """ Register a custom function with the given name. The function can be used (with this name) in every SQL queries from now on - but only for scalar operations (no aggregations). This means, if you register a function "f", you can now call .. code-block:: sql SELECT f(x) FROM df Please keep in mind that you can only have one function with the same name, regardless of whether it is an aggregation or a scalar function. By default, attempting to register two functions with the same name will raise an error; setting `replace=True` will give precedence to the most recently registered function. For the registration, you need to supply both the list of parameter and parameter types as well as the return type. Use `numpy dtypes `_ if possible. More information: :ref:`custom` Example: This example registers a function "f", which calculates the square of an integer and applies it to the column ``x``. .. code-block:: python def f(x): return x ** 2 c.register_function(f, "f", [("x", np.int64)], np.int64) sql = "SELECT f(x) FROM df" df_result = c.sql(sql) Example of overwriting two functions with the same name: This example registers a different function "f", which calculates the floor division of an integer and applies it to the column ``x``. It also shows how to overwrite the previous function with the replace parameter. .. 
def register_aggregation(
    self,
    f: dd.Aggregation,
    name: str,
    parameters: list[tuple[str, type]],
    return_type: type,
    replace: bool = False,
    schema_name: str = None,
):
    """
    Register a custom aggregation with the given name.
    The aggregation can be used (with this name)
    in every SQL queries from now on - but only for aggregation operations
    (no scalar function calls).
    This means, if you register a aggregation "fagg", you can now call

    .. code-block:: sql

        SELECT fagg(y)
        FROM df
        GROUP BY x

    Please note that you can always only have one function with the same name;
    no matter if it is an aggregation or scalar function.

    For the registration, you need to supply both the
    list of parameter and parameter types as well as the
    return type. Use
    `numpy dtypes <https://numpy.org/doc/stable/reference/arrays.dtypes.html>`_ if possible.

    More information: :ref:`custom`

    Example:
        The following code registers a new aggregation "fagg", which
        computes the sum of a column and uses it on the ``y`` column.

        .. code-block:: python

            fagg = dd.Aggregation("fagg", lambda x: x.sum(), lambda x: x.sum())
            c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64)

            sql = "SELECT fagg(y) FROM df GROUP BY x"
            df_result = c.sql(sql)

    Args:
        f (:class:`dask.dataframe.Aggregation`): The aggregation to register. See
            `the dask documentation <https://docs.dask.org/en/stable/dataframe-groupby.html#aggregate>`_
            for more information.
        name (:obj:`str`): Under which name should the new aggregation be addressable in SQL
        parameters (:obj:`List[Tuple[str, type]]`): A list of tuples of parameter name and parameter type.
            Use `numpy dtypes <https://numpy.org/doc/stable/reference/arrays.dtypes.html>`_ if possible.
        return_type (:obj:`type`): The return type of the function
        replace (:obj:`bool`): Do not raise an error if the function is already present
        schema_name (:obj:`str`): Passed through unchanged to the internal
            registration; defaults to ``None``.

    See also:
        :func:`register_function`

    """
    # Everything besides the callable and its name is handed over by keyword;
    # aggregation=True is what distinguishes this from register_function
    registration = {
        "aggregation": True,
        "parameters": parameters,
        "return_type": return_type,
        "replace": replace,
        "schema_name": schema_name,
    }
    self._register_callable(f, name, **registration)
def explain(
    self,
    sql: str,
    dataframes: dict[str, Union[dd.DataFrame, pd.DataFrame]] = None,
    gpu: bool = False,
) -> str:
    """
    Return the stringified relational algebra that this query will produce
    once triggered (with ``sql()``).
    Helpful to understand the inner workings of dask-sql, but typically not
    needed to query your data.

    If the query is of DDL type (e.g. CREATE TABLE or DESCRIBE SCHEMA),
    no relational algebra plan is created and therefore nothing returned.

    Unless ``sql.optimizer.verbose`` is enabled, the config option
    ``sql.dynamic_partition_pruning`` is switched off while the plan is
    created. Its previous value is restored afterwards, also when
    registering the dataframes or parsing the query fails.

    Args:
        sql (:obj:`str`): The query string to use
        dataframes (:obj:`Dict[str, dask.dataframe.DataFrame]`): additional Dask or pandas dataframes
            to register before executing this query
        gpu (:obj:`bool`): Whether or not to load the additional Dask or pandas dataframes (if any) on GPU;
            requires cuDF / dask-cuDF if enabled. Defaults to False.

    Returns:
        :obj:`str`: a description of the created relational algebra.

    """
    # Remember the current setting so that it can be put back afterwards
    dynamic_partition_pruning = dask_config.get("sql.dynamic_partition_pruning")
    if not dask_config.get("sql.optimizer.verbose"):
        dask_config.set({"sql.dynamic_partition_pruning": False})

    try:
        if dataframes is not None:
            for df_name, df in dataframes.items():
                self.create_table(df_name, df, gpu=gpu)

        _, rel_string = self._get_ral(sql)
    finally:
        # Restore the global option even if the query could not be parsed;
        # otherwise a failed explain() would leave pruning disabled
        dask_config.set({"sql.dynamic_partition_pruning": dynamic_partition_pruning})

    return rel_string
def ipython_magic(
    self, auto_include=False, disable_highlighting=True
):  # pragma: no cover
    """
    Register a new ipython/jupyter magic function "sql"
    which sends its input as string to the :func:`sql` function.
    After calling this magic function in a Jupyter notebook or
    an IPython shell, you can write

    .. code-block:: python

        %sql SELECT * from data

    or

    .. code-block:: python

        %%sql
        SELECT * from data

    instead of

    .. code-block:: python

        c.sql("SELECT * from data")

    Args:
        auto_include (:obj:`bool`): If set to true, automatically
            create a table for every pandas or Dask dataframe in the calling
            context. That means, if you define a dataframe in your jupyter
            notebook you can use it with the same name in your sql call.
            Use this setting with care as any defined dataframe can
            easily override tables created via `CREATE TABLE`.

            .. code-block:: python

                df = ...

                # Later, without any calls to create_table

                %%sql
                SELECT * FROM df

        disable_highlighting (:obj:`bool`): If set to true, automatically
            disable syntax highlighting. If you are working in jupyter lab,
            disable_highlighting must be set to true to enable ipython_magic
            functionality. If you are working in a classic jupyter notebook,
            you may set disable_highlighting=False if desired.
    """
    # All the work is delegated to the integration helper; this context is
    # passed along so the magic can run its queries against it.
    ipython_integration(
        self, auto_include=auto_include, disable_highlighting=disable_highlighting
    )
def fqn(self, tbl: "DaskTable") -> tuple[str, str]:
    """
    Resolve the (schema, table) pair an object should be addressed with.

    When the table itself carries no schema (``None`` or the empty string),
    the currently selected schema of this context is used instead.

    Args:
        tbl (:obj:`DaskTable`): The Rust DaskTable instance of the view or table.

    Returns:
        :obj:`tuple` of :obj:`str`: The fully qualified name of the object
    """
    declared_schema = tbl.getSchema()
    table = tbl.getTableName()

    has_own_schema = declared_schema is not None and declared_schema != ""
    resolved_schema = declared_schema if has_own_schema else self.schema_name

    return resolved_schema, table
def _get_ral(self, sql):
    """
    Helper function to turn the sql query into a relational algebra plan.

    Args:
        sql: The SQL string to parse.

    Returns:
        A tuple ``(rel, rel_string)``: the (optimized, if enabled and
        possible) logical plan and its ``explain_original()`` string.

    Raises:
        ParsingException: if parsing the SQL or building the logical plan fails.
        RuntimeError: if the SQL contains more than one statement.
    """

    logger.debug(f"Entering _get_ral('{sql}')")

    optimizer_config = DaskSQLOptimizerConfig(
        dask_config.get("sql.dynamic_partition_pruning"),
        dask_config.get("sql.fact_dimension_ratio"),
        dask_config.get("sql.max_fact_tables"),
        dask_config.get("sql.preserve_user_order"),
        dask_config.get("sql.filter_selectivity"),
    )

    self.context.set_optimizer_config(optimizer_config)

    # get the schema of what we currently have registered
    for schema in self._prepare_schemas():
        self.context.register_schema(schema.name, schema)

    try:
        sql_tree = self.context.parse_sql(sql)
    except DFParsingException as pe:
        # `from None` keeps the traceback focused on the user-facing
        # exception, matching the handling of the planning step below.
        raise ParsingException(sql, str(pe)) from None
    logger.debug(f"_get_ral -> sqlTree: {sql_tree}")

    # TODO: Need to understand if this list here is actually needed? For now just use the first entry.
    if len(sql_tree) > 1:
        raise RuntimeError(
            f"Multiple 'Statements' encountered for SQL {sql}. Please share this with the dev team!"
        )

    try:
        non_optimized_rel = self.context.logical_relational_algebra(sql_tree[0])
    except DFParsingException as pe:
        raise ParsingException(sql, str(pe)) from None

    # Optimize the `LogicalPlan` or skip if configured
    rel = non_optimized_rel
    if dask_config.get("sql.optimize"):
        try:
            rel = self.context.run_preoptimizer(non_optimized_rel)
            rel = self.context.optimize_relational_algebra(rel)
        except DFOptimizationException as oe:
            # Use original plan and warn about inability to optimize plan
            rel = non_optimized_rel
            logger.warning(str(oe))

    rel_string = rel.explain_original()
    logger.debug(f"_get_ral -> LogicalPlan: {rel}")
    logger.debug(f"Extracted relational algebra:\n {rel_string}")

    return rel, rel_string
def _get_tables_from_stack(self):
    """
    Collect every dataframe-like local variable found on the call stack.

    Frames are visited from the innermost to the outermost one, and the
    first variable seen under a given name wins, so inner scopes shadow
    outer ones. Names starting with an underscore are ignored.
    """
    frames = inspect.stack()

    found = {}
    for info in frames:
        local_vars = info.frame.f_locals
        for identifier, value in local_vars.items():
            is_public = not identifier.startswith("_")
            if is_public and dd.utils.is_dataframe_like(value):
                # keep the value from the inner-most frame that defined it
                found.setdefault(identifier, value)

    return found
class ColumnContainer:
    """
    Helper class to store a list of columns,
    which are not necessarily the ones of the dask dataframe.
    Instead, the container also stores a mapping from "frontend"
    columns (columns with the names and order expected by SQL)
    to "backend" columns (the real column names used by dask)
    to prevent unnecessary renames.

    All modifying methods return a new container and leave the
    instance they are called on untouched.
    """

    # NOTE: annotations naming ``ColumnContainer`` / ``ColumnType`` are given
    # as strings, so no placeholder class has to be declared ahead of this one.

    def __init__(
        self,
        frontend_columns: list[str],
        frontend_backend_mapping: "Union[dict[str, ColumnType], None]" = None,
    ):
        assert all(
            isinstance(col, str) for col in frontend_columns
        ), "All frontend columns need to be of string type"
        self._frontend_columns = list(frontend_columns)
        if frontend_backend_mapping is None:
            # Without an explicit mapping every column maps onto itself
            self._frontend_backend_mapping = {
                col: col for col in self._frontend_columns
            }
        else:
            self._frontend_backend_mapping = frontend_backend_mapping

    def _copy(self) -> "ColumnContainer":
        """
        Internal function to copy this container
        """
        return ColumnContainer(
            self._frontend_columns.copy(), self._frontend_backend_mapping.copy()
        )

    def limit_to(self, fields: list[str]) -> "ColumnContainer":
        """
        Create a new ColumnContainer, which has frontend columns
        limited to only the ones given as parameter.
        Also uses the order of these as the new column order.
        An empty ``fields`` returns this container itself.
        """
        if not fields:
            return self  # pragma: no cover

        assert all(f in self._frontend_backend_mapping for f in fields)
        cc = self._copy()
        cc._frontend_columns = [str(x) for x in fields]
        return cc

    def rename(self, columns: dict[str, str]) -> "ColumnContainer":
        """
        Return a new ColumnContainer where the frontend columns
        are renamed according to the given mapping.
        Columns not present in the mapping are not touched,
        the order is preserved.

        Raises:
            KeyError: if a column to rename is not known to this container.
        """
        # Normalize keys and values once, so that the mapping update and the
        # frontend column update below agree even for non-string keys.
        renames = {str(old): str(new) for old, new in columns.items()}

        cc = self._copy()
        for old, new in renames.items():
            # The old name stays in the mapping; only a new alias is added.
            cc._frontend_backend_mapping[new] = self._frontend_backend_mapping[old]

        cc._frontend_columns = [renames.get(col, col) for col in self._frontend_columns]

        return cc

    def rename_handle_duplicates(
        self, from_columns: list[str], to_columns: list[str]
    ) -> "ColumnContainer":
        """
        Same as `rename` but additionally handles presence of
        duplicates in `from_columns`: every target name is added to the
        mapping, while the frontend column takes the last target given
        for it.
        """
        pairs = [(str(old), str(new)) for old, new in zip(from_columns, to_columns)]

        cc = self._copy()
        cc._frontend_backend_mapping.update(
            {new: self._frontend_backend_mapping[old] for old, new in pairs}
        )

        # With duplicated source columns the last pair wins
        renames = dict(pairs)
        cc._frontend_columns = [renames.get(col, col) for col in self._frontend_columns]

        return cc

    def mapping(self) -> "list[tuple[str, ColumnType]]":
        """
        The mapping from frontend columns to backend columns.
        """
        return list(self._frontend_backend_mapping.items())

    @property
    def columns(self) -> list[str]:
        """
        The stored frontend columns in the correct order
        """
        return self._frontend_columns.copy()

    def add(
        self, frontend_column: str, backend_column: Union[str, None] = None
    ) -> "ColumnContainer":
        """
        Return a new ColumnContainer with the
        given column added.
        The column is added at the last position in the column list
        (or keeps its position if it is already present).
        If no backend column is given, the frontend name is used for it.
        """
        cc = self._copy()

        frontend_column = str(frontend_column)

        cc._frontend_backend_mapping[frontend_column] = str(
            backend_column or frontend_column
        )
        if frontend_column not in cc._frontend_columns:
            cc._frontend_columns.append(frontend_column)

        return cc

    def get_backend_by_frontend_index(self, index: int) -> str:
        """
        Get back the dask column, which is referenced by the
        frontend (SQL) column with the given index.
        """
        frontend_column = self._frontend_columns[index]
        return self._frontend_backend_mapping[frontend_column]

    def get_backend_by_frontend_name(self, column: str) -> str:
        """
        Get back the dask column, which is referenced by the
        frontend (SQL) column with the given name.
        Unknown names are returned unchanged.
        """
        return self._frontend_backend_mapping.get(column, column)

    def make_unique(self, prefix="col"):
        """
        Make sure we have unique column names by calling each column

            <prefix>_<number>

        where <number> is the column index.
        """
        return self.rename(
            columns={str(col): f"{prefix}_{i}" for i, col in enumerate(self.columns)}
        )
class UDF:
    def __init__(self, func, row_udf: bool, params, return_type=None):
        """
        Wrapper around a user defined function that knows how the
        function has to be applied to dask objects.

        There are two flavours:
        with `row_udf=False` the wrapped function receives its
        arguments unchanged (series-like objects are expected) and is
        simply called with them.
        With `row_udf=True` the function is applied row by row and
        should accept a dict-like object of scalars.
        """
        self.row_udf = row_udf
        self.func = func

        # Only the parameter names are needed, they label the columns
        # of the frame built for row UDFs
        parameter_names = []
        for param in params:
            parameter_names.append(param[0])
        self.names = parameter_names

        self.meta = (None, return_type)

    def __call__(self, *args, **kwargs):
        # Plain UDFs are a straight pass-through
        if not self.row_udf:
            return self.func(*args, **kwargs)

        # Split the operands into dask series and everything else
        series_operands = [op for op in args if isinstance(op, dd.Series)]
        other_operands = [op for op in args if not isinstance(op, dd.Series)]

        # Assemble a frame out of the series, one column per parameter name
        frame = series_operands[0].to_frame(self.names[0])
        remaining = zip(self.names[1:], series_operands[1:])
        for column_name, series in remaining:
            frame[column_name] = series

        applied = frame.apply(
            self.func, axis=1, args=tuple(other_operands), meta=self.meta
        )
        return applied.astype(self.meta[1])

    def __eq__(self, other):
        if not isinstance(other, UDF):
            return NotImplemented
        return self.func == other.func and self.row_udf == other.row_udf

    def __hash__(self):
        identity = (self.func, self.row_udf)
        return hash(identity)
class InputUtil(Pluggable):
    """
    Registry of input plugins plus the helpers which use them to turn
    whatever was handed to "create table" into a dask dataframe
    """

    @classmethod
    def add_plugin_class(cls, plugin_class: BaseInputPlugin, replace=True):
        """Instantiate the given plugin class and register the instance"""
        logger.debug(f"Registering Input plugin for {plugin_class}")
        plugin_key = str(plugin_class)
        plugin_instance = plugin_class()
        cls.add_plugin(plugin_key, plugin_instance, replace=replace)

    @classmethod
    def to_dc(
        cls,
        input_item: InputType,
        table_name: str,
        format: str = None,
        persist: bool = True,
        gpu: bool = False,
        **kwargs,
    ) -> DataContainer:
        """
        Load the data described by the input (e.g. dask dataframes,
        pandas dataframes, locations as string, hive tables) and wrap it
        into a data container. A list of inputs is concatenated into a
        single table. With ``persist`` set, the table is persisted
        before it is wrapped.
        """

        def load(*args):
            # Every input is loaded with the very same settings
            return cls._get_dask_dataframe(
                *args,
                table_name=table_name,
                format=format,
                gpu=gpu,
                **kwargs,
            )

        if not isinstance(input_item, list):
            table = load(input_item)
        else:
            table = dd.concat([load(item) for item in input_item])

        if persist:
            table = table.persist()

        column_container = ColumnContainer(table.columns)
        return DataContainer(table.copy(), column_container)

    @classmethod
    def _get_dask_dataframe(
        cls,
        input_item: InputType,
        table_name: str,
        format: str = None,
        gpu: bool = False,
        **kwargs,
    ):
        # The first plugin that accepts the input does the conversion
        for candidate in cls.get_plugins():
            accepted = candidate.is_correct_input(
                input_item, table_name=table_name, format=format, **kwargs
            )
            if not accepted:
                continue

            return candidate.to_dc(
                input_item, table_name=table_name, format=format, gpu=gpu, **kwargs
            )

        raise ValueError(f"Do not understand the input type {type(input_item)}")
def to_dc(
    self,
    input_item: Any,
    table_name: str,
    format: str = None,
    gpu: bool = False,
    **kwargs,
):
    """
    Read a Hive table into a dask dataframe.

    The table layout (columns, storage format, location, partitions) is
    taken from ``DESCRIBE FORMATTED``; the data files themselves are then
    read directly with the matching dask reader.

    Args:
        input_item: The hive cursor or sqlalchemy connection to query.
        table_name: Name of the table; overridden by the
            ``hive_table_name`` keyword argument if given.
        format: Ignored on input, it is derived from the table's "InputFormat".
        gpu: Not supported for Hive.
        **kwargs: ``hive_table_name`` and ``hive_schema_name`` (default
            ``"default"``) are consumed here, everything else is passed
            on to the dask read function.

    Raises:
        NotImplementedError: if ``gpu`` is requested.
        RuntimeError: if no "InputFormat" can be found in the description.
        AttributeError: if the storage format is not supported.
    """
    if gpu:  # pragma: no cover
        # Same exception type as the other plugins without GPU support
        raise NotImplementedError("Hive does not support gpu")

    table_name = kwargs.pop("hive_table_name", table_name)
    schema = kwargs.pop("hive_schema_name", "default")

    parsed = self._parse_hive_table_description(input_item, schema, table_name)
    (
        column_information,
        table_information,
        storage_information,
        partition_information,
    ) = parsed

    logger.debug("Extracted hive information: ")
    logger.debug(f"column information: {column_information}")
    logger.debug(f"table information: {table_information}")
    logger.debug(f"storage information: {storage_information}")
    logger.debug(f"partition information: {partition_information}")

    # Convert column information
    column_information = {
        col: sql_to_python_type(SqlTypeName.fromString(col_type.upper()))
        for col, col_type in column_information.items()
    }

    # Extract format information
    if "InputFormat" in storage_information:
        format = storage_information["InputFormat"].split(".")[-1]
    # databricks format is different, see https://github.com/dask-contrib/dask-sql/issues/83
    elif "InputFormat" in table_information:  # pragma: no cover
        format = table_information["InputFormat"].split(".")[-1]
    else:  # pragma: no cover
        raise RuntimeError(
            "Do not understand the output of 'DESCRIBE FORMATTED '"
        )

    parquet_formats = ("ParquetInputFormat", "MapredParquetInputFormat")

    if format in ("TextInputFormat", "SequenceFileInputFormat"):  # pragma: no cover
        storage_description = storage_information.get("Storage Desc Params", {})
        read_function = partial(
            dd.read_csv,
            sep=storage_description.get("field.delim", ","),
            header=None,
        )
    elif format in parquet_formats:
        read_function = dd.read_parquet
    elif format == "OrcInputFormat":  # pragma: no cover
        read_function = dd.read_orc
    elif format == "JsonInputFormat":  # pragma: no cover
        read_function = dd.read_json
    else:  # pragma: no cover
        raise AttributeError(f"Do not understand hive's table format {format}")

    def _normalize(loc):
        if loc.startswith("dbfs:/") and not loc.startswith(
            "dbfs://"
        ):  # pragma: no cover
            # dask (or better: fsspec) needs to have the URL in a specific form
            # starting with two // after the protocol
            loc = f"dbfs://{loc.removeprefix('dbfs:')}"
        # file:// is not a known protocol
        # (removeprefix, not lstrip: lstrip removes any leading run of the
        # characters "f", "i", "l", "e", ":" and would mangle e.g. "lakefs://...")
        loc = loc.removeprefix("file:")
        # Only allow files which do not start with . or _
        # Especially, not allow the _SUCCESS files
        return os.path.join(loc, "[A-Za-z0-9-]*")

    def wrapped_read_function(location, column_information, **kwargs):
        location = _normalize(location)
        logger.debug(f"Reading in hive data from {location}")
        if format in parquet_formats:
            # Hack needed for parquet files.
            # If the folder structure is like .../col=3/...
            # parquet wants to read in the partition information.
            # However, we add the partition information by ourself
            # which will lead to problems afterwards
            # Therefore tell parquet to only read in the columns
            # we actually care right now
            kwargs.setdefault("columns", list(column_information.keys()))
        else:  # pragma: no cover
            # prevent python to optimize it away and make coverage not respect the
            # pragma
            dummy = 0  # noqa: F841
        df = read_function(location, **kwargs)

        logger.debug(f"Applying column information: {column_information}")
        df = df.rename(columns=dict(zip(df.columns, column_information.keys())))

        for col, expected_type in column_information.items():
            df = cast_column_type(df, col, expected_type)

        return df

    if partition_information:
        partition_list = self._parse_hive_partition_description(
            input_item, schema, table_name
        )
        logger.debug(f"Reading in partitions from {partition_list}")

        tables = []
        for partition in partition_list:
            parsed = self._parse_hive_table_description(
                input_item, schema, table_name, partition=partition
            )
            (
                partition_column_information,
                partition_table_information,
                _,
                _,
            ) = parsed

            location = partition_table_information["Location"]
            table = wrapped_read_function(
                location, partition_column_information, **kwargs
            )

            # Now add the additional partition columns
            partition_values = ast.literal_eval(
                partition_table_information["Partition Value"]
            )

            # The values of all partition columns are returned as one comma
            # separated string. Always split it, so that a single value is
            # handled as a one element list and not indexed character by
            # character below.
            partition_values = [x.strip() for x in partition_values.split(",")]

            logger.debug(
                f"Applying additional partition information as columns: {partition_information}"
            )

            for partition_id, (partition_key, partition_type) in enumerate(
                partition_information.items()
            ):
                table[partition_key] = partition_values[partition_id]
                table = cast_column_type(table, partition_key, partition_type)

            tables.append(table)

        return dd.concat(tables)

    location = table_information["Location"]
    df = wrapped_read_function(location, column_information, **kwargs)
    return df
self, cursor: Union["sqlalchemy.engine.base.Connection", "hive.Cursor"], schema: str, table_name: str, partition: str = None, ): """ Extract all information from the output of the DESCRIBE FORMATTED call, which is unfortunately in a format not easily readable by machines. """ cursor.execute( sqlalchemy.text(f"USE {schema}") if self.is_sqlalchemy_hive(cursor) else f"USE {schema}" ) if partition: # Hive wants quoted, comma separated list of partition keys partition = partition.replace("=", '="') partition = partition.replace("/", '",') + '"' result = self._fetch_all_results( cursor, f"DESCRIBE FORMATTED {table_name} PARTITION ({partition})" ) else: result = self._fetch_all_results(cursor, f"DESCRIBE FORMATTED {table_name}") logger.debug(f"Got information from hive: {result}") table_information = {} column_information = {} # using the fact that dicts are insertion ordered storage_information = {} partition_information = {} mode = "column" last_field = None for key, value, value2 in result: key = key.strip().rstrip(":") if key else "" value = value.strip() if value else "" value2 = value2.strip() if value2 else "" # That is just a comment line, we can skip it if key == "# col_name": continue if ( key == "# Detailed Table Information" or key == "# Detailed Partition Information" ): mode = "table" elif key == "# Storage Information": mode = "storage" elif key == "# Partition Information": mode = "partition" elif key.startswith("#"): mode = None # pragma: no cover elif key: if not value: value = dict() if mode == "column": column_information[key] = value last_field = column_information[key] elif mode == "storage": storage_information[key] = value last_field = storage_information[key] elif mode == "table": # Hive partition values come in a bracketed list # quoted partition values work regardless of partition column type if key == "Partition Value": value = '"' + value.strip("[]") + '"' table_information[key] = value last_field = table_information[key] elif mode == 
def _parse_hive_partition_description(
    self,
    cursor: Union["sqlalchemy.engine.base.Connection", "hive.Cursor"],
    schema: str,
    table_name: str,
):
    """
    Return the partitions of the given table, i.e. the first column
    of every row that ``SHOW PARTITIONS`` yields.
    """
    # Switch to the schema first; sqlalchemy connections need the
    # statement wrapped into a text clause
    use_statement = f"USE {schema}"
    if self.is_sqlalchemy_hive(cursor):
        use_statement = sqlalchemy.text(use_statement)
    cursor.execute(use_statement)

    rows = self._fetch_all_results(cursor, f"SHOW PARTITIONS {table_name}")

    partitions = []
    for row in rows:
        partitions.append(row[0])
    return partitions
class IntakeCatalogInputPlugin(BaseInputPlugin):
    """Input plugin which reads a table out of an intake catalog as a dask dataframe"""

    def is_correct_input(
        self, input_item: Any, table_name: str, format: str = None, **kwargs
    ):
        """
        Accept intake catalog objects as well as everything
        registered with ``format="intake"``.
        Without an importable intake module, the (falsy) module placeholder
        is handed back, so nothing is accepted.
        """
        if not intake:
            return intake

        if isinstance(input_item, intake.catalog.Catalog):
            return True
        return format == "intake"

    def to_dc(
        self,
        input_item: Any,
        table_name: str,
        format: str = None,
        gpu: bool = False,
        **kwargs,
    ):
        """
        Look up the table in the catalog and convert it with ``to_dask``.

        The keyword argument ``intake_table_name`` overrides the name of the
        catalog entry (default: ``table_name``), ``catalog_kwargs`` is passed on
        to ``intake.open_catalog`` if the catalog is given as a string.
        All remaining keyword arguments are forwarded to ``to_dask``.
        """
        if gpu:  # pragma: no cover
            raise NotImplementedError("Intake does not support gpu")

        entry_name = kwargs.pop("intake_table_name", table_name)
        open_options = kwargs.pop("catalog_kwargs", {})

        catalog = input_item
        if isinstance(catalog, str):
            # a string is interpreted as the location of the catalog
            catalog = intake.open_catalog(catalog, **open_options)

        catalog_entry = catalog[entry_name]
        return catalog_entry.to_dask(**kwargs)
class PandasLikeInputPlugin(BaseInputPlugin):
    """Input plugin turning pandas-like dataframes into dask dataframes"""

    def is_correct_input(
        self, input_item, table_name: str, format: str = None, **kwargs
    ):
        """
        Accept everything which looks like a dataframe without already being
        a dask dataframe, as well as every input registered with ``format="dask"``.
        """
        is_non_dask_frame = dd.utils.is_dataframe_like(
            input_item
        ) and not isinstance(input_item, dd.DataFrame)
        if is_non_dask_frame:
            return is_non_dask_frame
        return format == "dask"

    def to_dc(
        self,
        input_item,
        table_name: str,
        format: str = None,
        gpu: bool = False,
        **kwargs,
    ):
        """
        Convert the input with ``dd.from_pandas``.

        The number of partitions is taken from the keyword argument
        ``npartitions`` (default: 1), all remaining keyword arguments are
        forwarded to ``dd.from_pandas``. With ``gpu=True``, pandas dataframes
        are moved to cudf before the conversion.
        """
        partition_count = kwargs.pop("npartitions", 1)

        if gpu:  # pragma: no cover
            try:
                import cudf
            except ImportError:
                raise ModuleNotFoundError(
                    "Setting `gpu=True` for table creation requires cudf"
                )

            if isinstance(input_item, pd.DataFrame):
                input_item = cudf.from_pandas(input_item)

        return dd.from_pandas(input_item, npartitions=partition_count, **kwargs)
"""Input Plugin from sqlalchemy string""" def is_correct_input( self, input_item: Any, table_name: str, format: str = None, **kwargs ): correct_prefix = isinstance(input_item, str) and ( input_item.startswith("hive://") or input_item.startswith("databricks+pyhive://") ) return correct_prefix def to_dc( self, input_item: Any, table_name: str, format: str = None, gpu: bool = False, **kwargs ): # pragma: no cover if gpu: raise NotImplementedError("Hive does not support gpu") import sqlalchemy engine_kwargs = {} if "connect_args" in kwargs: engine_kwargs["connect_args"] = kwargs.pop("connect_args") if format is not None: raise AttributeError( "Format specified and sqlalchemy connection string set!" ) cursor = sqlalchemy.create_engine(input_item, **engine_kwargs).connect() return super().to_dc(cursor, table_name=table_name, **kwargs) ================================================ FILE: dask_sql/integrations/__init__.py ================================================ ================================================ FILE: dask_sql/integrations/fugue.py ================================================ try: import fugue import fugue_dask from dask.distributed import Client from fugue import WorkflowDataFrame, register_execution_engine from fugue_sql import FugueSQLWorkflow from triad import run_at_def from triad.utils.convert import get_caller_global_local_vars except ImportError: # pragma: no cover raise ImportError( "Can not load the fugue module. If you want to use this integration, you need to install it." ) from typing import Any, Optional import dask.dataframe as dd from dask_sql.context import Context @run_at_def def _register_engines() -> None: """Register (overwrite) the default Dask execution engine of Fugue. This function is invoked as an entrypoint, users don't need to call it explicitly. 
def fsql_dask(
    sql: str,
    ctx: Optional[Context] = None,
    register: bool = False,
    fugue_conf: Any = None,
) -> dict[str, dd.DataFrame]:
    """Run a FugueSQL statement, optionally on the tables of a dask-sql Context.

    FugueSQL is a language extending standard SQL, which makes it possible to
    describe end to end workflows and to invoke python extensions from within
    the SQL like language.
    For more, please read the `FugueSQL Tutorial <https://fugue-tutorials.readthedocs.io/>`_

    Args:
        sql (:obj:`str`): Fugue SQL statement
        ctx (:class:`dask_sql.Context`): The context whose tables (of the currently
            selected schema) are made available to the statement, defaults to None
        register (:obj:`bool`): Whether to register the named steps of the workflow
            back to the context (if provided), defaults to False
        fugue_conf (:obj:`Any`): a dictionary like object containing Fugue specific configs

    Returns:
        A dictionary mapping the names of the workflow steps to their (native) results.

    Example:
        .. code-block:: python

            # define a custom prepartition function for FugueSQL
            def median(df: pd.DataFrame) -> pd.DataFrame:
                df["y"] = df["y"].median()
                return df.head(1)

            # create a context with some tables
            c = Context()
            ...

            # run a FugueSQL query using the context as input
            query = '''
            j = SELECT df1.*, df2.x
                FROM df1 INNER JOIN df2 ON df1.key = df2.key
                PERSIST
            TAKE 5 ROWS PREPARTITION BY x PRESORT key
            PRINT
            TRANSFORM j PREPARTITION BY x USING median
            PRINT
            '''
            result = fsql_dask(query, c, register=True)

            assert "j" in result
            assert "j" in c.tables
    """
    # NOTE(review): this presumably inspects the call stack to find the variables
    # of our caller - keep the call directly in this function
    _global, _local = get_caller_global_local_vars()

    dag = FugueSQLWorkflow()

    # Hand every table of the current schema over to the workflow
    dfs = {}
    if ctx is not None:
        for table_name, container in ctx.schema[ctx.schema_name].tables.items():
            dfs[table_name] = dag.df(container.df)

    result = dag._sql(sql, _global, _local, **dfs)
    dag.run(DaskSQLExecutionEngine(conf=fugue_conf))

    # Only the workflow dataframes are of interest for the caller
    result_dfs = {}
    for step_name, step in result.items():
        if isinstance(step, WorkflowDataFrame):
            result_dfs[step_name] = step.result.native

    if register and ctx is not None:
        for step_name, native_df in result_dfs.items():
            ctx.create_table(step_name, native_df)

    return result_dfs
def _created_table_name(sql_statement: str):  # pragma: no cover
    """
    Return the name of the table or view created by the statement.

    Recognized are statements of the form
    ``CREATE [OR REPLACE] TABLE|VIEW [IF NOT EXISTS] <name> ...``,
    independent of the case of the keywords. The name is the first
    whitespace-delimited token after these keywords and is returned unchanged.
    If the statement does not create a table or view, None is returned.
    """
    import re

    created = re.search(
        r"\bCREATE\s+(?:OR\s+REPLACE\s+)?(?:TABLE|VIEW)\s+"
        r"(?:IF\s+NOT\s+EXISTS\s+)?(\S+)",
        sql_statement,
        flags=re.IGNORECASE,
    )
    return created.group(1) if created else None


def _register_ipython_magic(
    c: "dask_sql.Context", auto_include: bool
) -> None:  # pragma: no cover
    """
    Register the ``sql`` line and cell magic, which executes its content on the context ``c``.

    The statement is formatted with the variables of the local namespace before
    it is executed. If ``auto_include`` is set, the dataframes found by
    ``c._get_tables_from_stack()`` are passed to the query as well.
    For statements creating a table or view, the tail of the newly created
    table is returned instead of the result of the statement itself.
    """
    from IPython.core.magic import needs_local_scope, register_line_cell_magic

    @needs_local_scope
    def sql(line, cell, local_ns):
        if cell is None:
            # the magic function was called inline
            cell = line

        sql_statement = cell.format(**local_ns)

        dataframes = {}
        if auto_include:
            dataframes = c._get_tables_from_stack()

        t0 = time.time()
        res = c.sql(sql_statement, return_futures=False, dataframes=dataframes)

        # Show a preview for newly created tables. The name is extracted with a
        # pattern instead of removing the keywords from the statement, as the
        # latter would also mangle names containing "TABLE" or "VIEW".
        table = _created_table_name(sql_statement)
        if table is not None:
            res = c.sql(f"SELECT * FROM {table}").tail()

        print(f"Execution time: {time.time() - t0:.2f}s")
        return res

    # Register a new magic function
    magic_func = register_line_cell_magic(sql)
    magic_func.MAGIC_NO_VAR_EXPAND_ATTR = True
def _create_set(keys: list[str]) -> dict[str, bool]:  # pragma: no cover
    """Build the dictionary format codemirror expects for a set: every key mapped to True"""
    return dict.fromkeys(keys, True)
logging.getLogger(__name__) # Default mapping between python types and SQL types _PYTHON_TO_SQL = { np.float64: SqlTypeName.DOUBLE, pd.Float64Dtype(): SqlTypeName.DOUBLE, float: SqlTypeName.FLOAT, np.float32: SqlTypeName.FLOAT, pd.Float32Dtype(): SqlTypeName.FLOAT, np.int64: SqlTypeName.BIGINT, pd.Int64Dtype(): SqlTypeName.BIGINT, int: SqlTypeName.INTEGER, np.int32: SqlTypeName.INTEGER, pd.Int32Dtype(): SqlTypeName.INTEGER, np.int16: SqlTypeName.SMALLINT, pd.Int16Dtype(): SqlTypeName.SMALLINT, np.int8: SqlTypeName.TINYINT, pd.Int8Dtype(): SqlTypeName.TINYINT, np.uint64: SqlTypeName.BIGINT, pd.UInt64Dtype(): SqlTypeName.BIGINT, np.uint32: SqlTypeName.INTEGER, pd.UInt32Dtype(): SqlTypeName.INTEGER, np.uint16: SqlTypeName.SMALLINT, pd.UInt16Dtype(): SqlTypeName.SMALLINT, np.uint8: SqlTypeName.TINYINT, pd.UInt8Dtype(): SqlTypeName.TINYINT, np.bool_: SqlTypeName.BOOLEAN, pd.BooleanDtype(): SqlTypeName.BOOLEAN, str: SqlTypeName.VARCHAR, np.object_: SqlTypeName.VARCHAR, pd.StringDtype(): SqlTypeName.VARCHAR, np.datetime64: SqlTypeName.TIMESTAMP, } # Default mapping between SQL types and python types # for values _SQL_TO_PYTHON_SCALARS = { "SqlTypeName.DOUBLE": np.float64, "SqlTypeName.FLOAT": np.float32, "SqlTypeName.DECIMAL": np.float32, "SqlTypeName.BIGINT": np.int64, "SqlTypeName.INTEGER": np.int32, "SqlTypeName.SMALLINT": np.int16, "SqlTypeName.TINYINT": np.int8, "SqlTypeName.BOOLEAN": np.bool_, "SqlTypeName.VARCHAR": str, "SqlTypeName.CHAR": str, "SqlTypeName.NULL": type(None), "SqlTypeName.SYMBOL": lambda x: x, # SYMBOL is a special type used for e.g. flags etc. 
def parse_datetime(obj):
    """
    Parse the given string into a datetime object.

    The supported formats are tried one after the other, the first one
    which fits the string is used. If none of them fits, a ValueError is raised.
    """
    # The order is important: the first matching format wins
    candidate_formats = (
        "%Y-%m-%d %H:%M:%S",
        "%Y-%m-%d",
        "%d-%m-%Y %H:%M:%S",
        "%d-%m-%Y",
        "%m/%d/%Y %H:%M:%S",
        "%m/%d/%Y",
    )

    for candidate in candidate_formats:
        try:
            parsed = datetime.strptime(obj, candidate)
        except ValueError:
            # does not fit, maybe the next one does
            continue
        else:
            return parsed

    raise ValueError("Unable to parse datetime: " + obj)
logger.debug( f"sql_to_python_value -> sql_type: {sql_type} literal_value: {literal_value}" ) if sql_type == SqlTypeName.CHAR or sql_type == SqlTypeName.VARCHAR: # Some varchars contain an additional encoding # in the format _ENCODING'string' literal_value = str(literal_value) if literal_value.startswith("_"): encoding, literal_value = literal_value.split("'", 1) literal_value = literal_value.rstrip("'") literal_value = literal_value.encode(encoding=encoding) return literal_value.decode(encoding=encoding) return literal_value elif ( sql_type == SqlTypeName.DECIMAL and dask_config.get("sql.mappings.decimal_support") == "cudf" ): from decimal import Decimal python_type = Decimal elif sql_type == SqlTypeName.INTERVAL_DAY: return np.timedelta64(literal_value[0], "D") + np.timedelta64( literal_value[1], "ms" ) elif sql_type == SqlTypeName.INTERVAL: # check for finer granular interval types, e.g., INTERVAL MONTH, INTERVAL YEAR try: interval_type = str(sql_type).split()[1].lower() if interval_type in {"year", "quarter", "month"}: # if sql_type is INTERVAL YEAR, Calcite will covert to months delta = pd.tseries.offsets.DateOffset(months=float(str(literal_value))) return delta except IndexError: # pragma: no cover # no finer granular interval type specified pass except TypeError: # pragma: no cover # interval type is not recognized, fall back to default case pass # Calcite will always convert INTERVAL types except YEAR, QUATER, MONTH to milliseconds # Issue: if sql_type is INTERVAL MICROSECOND, and value <= 1000, literal_value will be rounded to 0 return np.timedelta64(literal_value, "ms") elif sql_type == SqlTypeName.INTERVAL_MONTH_DAY_NANOSECOND: # DataFusion assumes 30 days per month. 
def similar_type(lhs: type, rhs: type) -> bool:
    """
    Decide, if the two types belong to the same family of types.

    Families are e.g. the unsigned integers, the signed integers,
    the floats, the strings etc. - neither the size nor the precision
    of the types matters. The only exception are decimals,
    which are only similar if precision and scale are equal.

    TODO: nullability is not checked so far.
    """
    pdt = pd.api.types

    def is_tz_aware_datetime(dtype):
        return pdt.is_datetime64_ns_dtype(dtype) and isinstance(
            dtype, pd.DatetimeTZDtype
        )

    def is_tz_naive_datetime(dtype):
        return pdt.is_datetime64_ns_dtype(dtype) and not isinstance(
            dtype, pd.DatetimeTZDtype
        )

    def is_plain_string(dtype):
        # is_string_dtype considers decimal columns to be string columns
        return pdt.is_string_dtype(dtype) and not is_decimal(dtype)

    # One membership test per family, the decimals are handled separately below
    families = (
        pdt.is_unsigned_integer_dtype,
        pdt.is_signed_integer_dtype,
        pdt.is_float_dtype,
        pdt.is_object_dtype,
        is_plain_string,
        is_tz_aware_datetime,
        is_tz_naive_datetime,
        pdt.is_timedelta64_ns_dtype,
        pdt.is_bool_dtype,
    )

    for in_family in families:
        if in_family(lhs) and in_family(rhs):
            return True

    if is_decimal(lhs) and is_decimal(rhs):
        # check that decimal columns have equal precision/scale
        return lhs.precision == rhs.precision and lhs.scale == rhs.scale

    return False
def cast_column_to_type(col: dd.Series, expected_type: type):
    """
    Cast the given column to the expected type.

    Returns None if the current type of the column is already similar
    to the expected type (see `similar_type`), which means no cast is needed.
    Otherwise, the casted column is returned. The following cases are special:

    * float to integer: the values are truncated before the cast
    * timedelta to integer: the column is casted to np.int64
    * timezone-aware to timezone-naive datetime: the timezone information is dropped
    """
    pdt = pd.api.types

    is_dt_ns = pdt.is_datetime64_ns_dtype
    is_dt_tz = lambda t: is_dt_ns(t) and isinstance(t, pd.DatetimeTZDtype)
    is_dt_ntz = lambda t: is_dt_ns(t) and not isinstance(t, pd.DatetimeTZDtype)

    current_type = col.dtype

    if similar_type(current_type, expected_type):
        logger.debug("...not converting.")
        return None

    if pdt.is_integer_dtype(expected_type):
        if pd.api.types.is_float_dtype(current_type):
            logger.debug("...truncating...")
            # Currently "trunc" can not be applied to NA (the pandas missing value type),
            # because NA is a different type. It works with np.nan though.
            # For our use case, that does not matter, as the conversion to integer later
            # will convert both NA and np.nan to NA.
            # Please note: the alias np.NaN was removed in numpy 2.0,
            # np.nan works with every numpy version.
            col = da.trunc(col.fillna(value=np.nan))
        elif pdt.is_timedelta64_dtype(current_type):
            logger.debug(f"Explicitly casting from {current_type} to np.int64")
            return col.astype(np.int64)

    if is_dt_tz(current_type) and is_dt_ntz(expected_type):
        # casting from timezone-aware to timezone-naive datatypes with astype is deprecated in pandas 2
        return col.dt.tz_localize(None)

    logger.debug(f"Need to cast from {current_type} to {expected_type}")
    return col.astype(expected_type)
class BaseRelPlugin:
    """
    Common parent of all plugins, which turn a relational plan node
    into a python expression (a dask dataframe).

    Derived classes are expected to set the class_name attribute
    and to implement the convert method.
    """

    class_name = None

    def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> dd.DataFrame:
        """Do the conversion - needs to be implemented by the derived classes"""
        raise NotImplementedError

    @staticmethod
    def fix_column_to_row_type(
        cc: ColumnContainer, row_type: "RelDataType", join_type: Optional[str] = None
    ) -> ColumnContainer:
        """
        Rename the columns of the column container to the field names of the row type
        and limit the container to these columns.
        The renaming is done by position only, so the order of the columns
        is assumed to be correct already.
        For left semi and left anti joins, only as many field names
        as there are columns in the container are used.
        """
        field_names = list(map(str, row_type.getFieldNames()))
        if join_type in ("leftsemi", "leftanti"):
            field_names = field_names[: len(cc.columns)]

        logger.debug(f"Renaming {cc.columns} to {field_names}")
        renamed_container = cc.rename_handle_duplicates(
            from_columns=cc.columns, to_columns=field_names
        )

        # TODO: We can also check for the types here and do any conversions if needed
        return renamed_container.limit_to(field_names)

    @staticmethod
    def check_columns_from_row_type(df: dd.DataFrame, row_type: "RelDataType"):
        """
        Counterpart of `self.fix_column_to_row_type`: the column names of the
        dataframe are not changed, but asserted to be the field names of the row type.
        """
        expected_names = list(map(str, row_type.getFieldNames()))
        actual_names = list(df.columns)

        assert actual_names == expected_names

        # TODO: similar to self.fix_column_to_row_type, we should check for the types

    @staticmethod
    def assert_inputs(
        rel: "LogicalPlan",
        n: int = 1,
        context: "dask_sql.Context" = None,
    ) -> list[dd.DataFrame]:
        """
        Convert the inputs of the given LogicalPlan (the nodes it is built on top of)
        and return the results as a list.
        Before the conversion it is asserted, that the plan has exactly n inputs.
        """
        inputs = rel.get_inputs()

        assert len(inputs) == n

        # Late import to remove cycling dependency
        from dask_sql.physical.rel.convert import RelConverter

        converted_inputs = []
        for single_input in inputs:
            converted_inputs.append(RelConverter.convert(single_input, context))
        return converted_inputs

    @staticmethod
    def fix_dtype_to_row_type(
        dc: DataContainer, row_type: "RelDataType", join_type: Optional[str] = None
    ):
        """
        Cast the columns of the dataframe in the data container to the
        types defined by the row type and return a new data container.

        The cast itself is left to `cast_column_type`, which only converts
        if really needed, e.g. if the two types are "similar" enough,
        nothing is converted. Similarity involves the same general type
        (int, float, string etc) but not necessary the size
        (int64 and int32 are compatible) or the nullability.
        For left semi and left anti joins, only as many fields
        as there are columns in the container are taken into account.

        TODO: we should check the nullability of the SQL type
        """
        frame = dc.df
        columns = dc.column_container

        fields = row_type.getFieldList()
        if join_type in ("leftsemi", "leftanti"):
            fields = fields[: len(columns.columns)]

        # Collect the types of all fields before touching the dataframe
        types_by_name = {}
        for field in fields:
            qualified_name = str(field.getQualifiedName())
            types_by_name[qualified_name] = field.getType()

        for frontend_name, field_type in types_by_name.items():
            sql_type = field_type.getSqlType()

            # Only decimals need additional arguments (precision and scale)
            if str(sql_type) == "SqlTypeName.DECIMAL":
                type_arguments = field_type.getDataType().getPrecisionScale()
            else:
                type_arguments = tuple()

            target_dtype = sql_to_python_type(sql_type, *type_arguments)
            backend_name = columns.get_backend_by_frontend_name(frontend_name)

            frame = cast_column_type(frame, backend_name, target_dtype)

        return DataContainer(frame, dc.column_container)
class RelConverter(Pluggable):
    """
    Registry of the plugins, which convert relational plan nodes
    into python expressions (typically dask dataframes).

    Every plugin is stored under its class attribute "class_name",
    which is the type of the nodes it is able to convert.
    The plugins need to have an instance method in the form

        def convert(self, rel, context)

    which does the actual conversion.
    """

    @classmethod
    def add_plugin_class(cls, plugin_class: BaseRelPlugin, replace=True):
        """Shortcut to register an instance of the given plugin class under its class_name"""
        logger.debug(f"Registering REL plugin for {plugin_class.class_name}")

        plugin = plugin_class()
        cls.add_plugin(plugin_class.class_name, plugin, replace=replace)

    @classmethod
    def convert(cls, rel: "LogicalPlan", context: "dask_sql.Context") -> dd.DataFrame:
        """
        Turn the given plan node into a python expression (a dask dataframe).

        The plugin to use is chosen by the type of the current node of the plan.
        The plugin is called with the plan and the context, which gives it access
        to the registered dask tables.
        If no plugin is registered for the node type, a NotImplementedError is raised.
        """
        node_type = rel.get_current_node_type()

        try:
            plugin_instance = cls.get_plugin(node_type)
        except KeyError:  # pragma: no cover
            raise NotImplementedError(
                f"No relational conversion for node type {node_type} available (yet)."
            )

        logger.debug(
            f"Processing REL {rel} using {plugin_instance.__class__.__name__}..."
        )
        converted = plugin_instance.convert(rel, context=context)
        logger.debug(f"Processed REL {rel} into {LoggableDataFrame(converted)}")

        return converted
class AlterTablePlugin(BaseRelPlugin):
    """
    Rename a table of the context.
    The SQL call looks like

        ALTER TABLE [IF EXISTS] <old-table-name> RENAME TO <new-table-name>

    Using this SQL is equivalent to just doing

        context.alter_table(<old-table-name>, <new-table-name>)

    but can also be used without writing a single line of code.
    Nothing is returned.
    """

    class_name = "AlterTable"

    def convert(self, rel: "LogicalPlan", context: "dask_sql.Context"):
        alter_table = rel.alter_table()

        old_table_name = alter_table.getOldTableName()
        new_table_name = alter_table.getNewTableName()
        # Without an explicit schema, the currently selected one is used
        schema_name = alter_table.getSchemaName() or context.schema_name

        logger.info(
            f"changing table name from `{old_table_name}` to `{new_table_name}`"
        )

        table_is_known = old_table_name in context.schema[schema_name].tables
        if not table_is_known:
            if alter_table.getIfExists():
                # IF EXISTS was specified: a missing table is not an error
                return

            raise KeyError(
                f"Table {old_table_name} was not found, available tables in {schema_name} are "
                f"- {context.schema[schema_name].tables.keys()}"
            )

        context.alter_table(
            old_table_name=old_table_name,
            new_table_name=new_table_name,
            schema_name=schema_name,
        )
class AnalyzeTablePlugin(BaseRelPlugin):
    """
    Compute summary statistics (mean, max, ...) for all or some columns
    of a registered table.
    The SQL is:

        ANALYZE TABLE <table> COMPUTE STATISTICS FOR [ALL COLUMNS | COLUMNS a, b, ...]

    The statistics are returned as a table which is built on the fly.

    Please note: the syntax resembles e.g.
    [the spark version](https://spark.apache.org/docs/3.0.0/sql-ref-syntax-aux-analyze-table.html),
    but other than in spark the call does not help with query optimization,
    as this is currently not implemented in dask-sql.
    """

    class_name = "AnalyzeTable"

    def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
        statement = rel.analyze_table()

        schema_name = statement.getSchemaName() or context.schema_name
        table = context.schema[schema_name].tables[statement.getTableName()]

        # No explicit column list means "analyze every column of the table"
        columns = statement.getColumns() or table.column_container.columns

        # Translate the SQL (frontend) names into the dataframe (backend) names once
        to_backend = table.column_container.get_backend_by_frontend_name
        backend_names = [to_backend(col) for col in columns]
        df = table.df

        # One frame with the numeric summary, one row with the SQL type of
        # every column and one row with its SQL name
        summary = df[backend_names].describe()
        data_types = pd.DataFrame(
            {
                name: str(python_to_sql_type(df[name].dtype)).lower()
                for name in backend_names
            },
            index=["data_type"],
        )
        column_names = pd.DataFrame(
            dict(zip(backend_names, columns)), index=["col_name"]
        )
        statistics = dd.concat([summary, data_types, column_names])

        return DataContainer(statistics, ColumnContainer(statistics.columns))
def convert(self, rel: "LogicalPlan", context: "dask_sql.Context"):
    """
    Register the schema named in the statement at the context.

    An already existing schema is left untouched for IF NOT EXISTS,
    re-created for OR REPLACE and reported with a RuntimeError otherwise.
    """
    statement = rel.create_catalog_schema()
    schema_name = statement.getSchemaName()

    already_registered = schema_name in context.schema
    if already_registered and statement.getIfNotExists():
        return
    if already_registered and not statement.getReplace():
        raise RuntimeError(
            f"A Schema with the name {schema_name} is already present."
        )

    context.create_schema(schema_name)
You might need to install necessary packages to use the models. * experiment_class : Class name or full path of the Hyperparameter tuner. Any sklearn or cuML classes can be inferred without the full path. In this case, models trained on cuDF dataframes are automatically mapped to cuML classes, and sklearn models otherwise. * tune_parameters: Key-value of pairs of Hyperparameters to tune, i.e Search Space for particular model to tune * automl_class : Full path of the class which is sklearn compatible and able to distribute work to dask clusters, currently tested with tpot automl framework. Refer : [Tpot example](https://examples.dask.org/machine-learning/tpot.html) * target_column: Which column from the data to use as target. Currently this parameter is required field, because tuning and automl behaviour is implemented only for supervised algorithms. * automl_kwargs: Key-value pairs of arguments to be passed to automl class . Refer : [Using Tpot parameters](https://epistasislab.github.io/tpot/using/) * experiment_kwargs: Use this parameter for passing any keyword arguments to experiment class * tune_fit_kwargs: Use this parameter for passing any keyword arguments to experiment.fit() method example: for Hyperparameter tuning : (Train and evaluate same model with different parameters) CREATE EXPERIMENT my_exp WITH( model_class = 'sklearn.ensemble.GradientBoostingClassifier', experiment_class = 'sklearn.model_selection.GridSearchCV', tune_parameters = (n_estimators = ARRAY [16, 32, 2], learning_rate = ARRAY [0.1,0.01,0.001], max_depth = ARRAY [3,4,5,10] ), target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) for automl : (Train different different model with different parameter) CREATE EXPERIMENT my_exp WITH ( automl_class = 'tpot.TPOTClassifier', automl_kwargs = (population_size = 2 , generations=2, cv=2, n_jobs=-1, use_dask=True, max_eval_time_mins=1), target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries 
def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
    """
    Run a CREATE EXPERIMENT statement.

    Depending on the given options either ``model_class`` is tuned with
    ``experiment_class`` or the ``automl_class`` is fitted on the result of
    the SELECT query. The best estimator is registered as a model and the
    result table as an experiment, both under the experiment name.

    Returns the result table as a DataContainer. Nothing is returned if the
    experiment already exists and IF NOT EXISTS was specified.

    Raises a RuntimeError if the experiment already exists (and neither
    IF NOT EXISTS nor OR REPLACE was given) and a ValueError if mandatory
    options are missing or one of the classes can not be imported.
    """
    create_experiment = rel.create_experiment()
    select = create_experiment.getSelectQuery()

    schema_name = create_experiment.getSchemaName() or context.schema_name
    experiment_name = create_experiment.getExperimentName()
    kwargs = convert_sql_kwargs(create_experiment.getSQLWithOptions())

    if experiment_name in context.schema[schema_name].experiments:
        if create_experiment.getIfNotExists():
            return
        elif not create_experiment.getOrReplace():
            raise RuntimeError(
                f"A experiment with the name {experiment_name} is already present."
            )

    logger.debug(
        f"Creating Experiment {experiment_name} from query {select} with options {kwargs}"
    )
    model_class = None
    automl_class = None
    experiment_class = None
    if "model_class" in kwargs:
        model_class = kwargs.pop("model_class")
        # when model class was provided, must provide experiment_class also for tuning
        if "experiment_class" not in kwargs:
            raise ValueError(
                f"Parameters must include a 'experiment_class' parameter for tuning {model_class}."
            )
        experiment_class = kwargs.pop("experiment_class")
    elif "automl_class" in kwargs:
        automl_class = kwargs.pop("automl_class")
    else:
        raise ValueError(
            "Parameters must include a 'model_class' or 'automl_class' parameter."
        )
    target_column = kwargs.pop("target_column", "")
    tune_fit_kwargs = kwargs.pop("tune_fit_kwargs", {})
    parameters = kwargs.pop("tune_parameters", {})
    experiment_kwargs = kwargs.pop("experiment_kwargs", {})
    automl_kwargs = kwargs.pop("automl_kwargs", {})

    # Validate the options before the training query is planned, so that
    # an incomplete statement fails fast
    if not target_column:
        raise ValueError(
            "Unsupervised Algorithm cannot be tuned Automatically, "
            "Consider providing 'target column'"
        )
    logger.info(parameters)

    # Imported lazily, as in the rest of this module
    from dask_sql.physical.rel.custom.wrappers import ParallelPostFit

    training_df = context.sql(select)
    non_target_columns = [col for col in training_df.columns if col != target_column]
    X = training_df[non_target_columns]
    y = training_df[target_column]

    if model_class and experiment_class:
        # Short class names are resolved to their cuML / sklearn counterpart
        # depending on the type of the training data
        known_classes = gpu_classes if is_cudf_type(training_df) else cpu_classes
        model_class = known_classes.get(model_class, model_class)
        experiment_class = known_classes.get(experiment_class, experiment_class)

        try:
            ModelClass = import_class(model_class)
        except ImportError as err:
            raise ValueError(
                f"Can not import model {model_class}. Make sure you spelled it correctly and have installed all packages."
            ) from err
        try:
            ExperimentClass = import_class(experiment_class)
        except ImportError as err:
            raise ValueError(
                f"Can not import tuner {experiment_class}. Make sure you spelled it correctly and have installed all packages."
            ) from err

        model = ModelClass()

        search = ExperimentClass(model, {**parameters}, **experiment_kwargs)
        logger.info(tune_fit_kwargs)
        search.fit(
            X.to_dask_array(lengths=True),
            y.to_dask_array(lengths=True),
            **tune_fit_kwargs,
        )
        df = pd.DataFrame(search.cv_results_)
        df["model_class"] = model_class

        context.register_model(
            experiment_name,
            ParallelPostFit(estimator=search.best_estimator_),
            X.columns,
            schema_name=schema_name,
        )

    if automl_class:
        try:
            AutoMLClass = import_class(automl_class)
        except ImportError as err:
            raise ValueError(
                f"Can not import automl model {automl_class}. Make sure you spelled it correctly and have installed all packages."
            ) from err

        automl = AutoMLClass(**automl_kwargs)
        # should be avoided if data doesn't fit in memory
        automl.fit(X.compute(), y.compute())
        df = (
            pd.DataFrame(automl.evaluated_individuals_)
            .T.reset_index()
            .rename({"index": "models"}, axis=1)
        )

        context.register_model(
            experiment_name,
            ParallelPostFit(estimator=automl.fitted_pipeline_),
            X.columns,
            schema_name=schema_name,
        )

    context.register_experiment(
        experiment_name, experiment_results=df, schema_name=schema_name
    )
    cc = ColumnContainer(df.columns)
    dc = DataContainer(dd.from_pandas(df, npartitions=1), cc)
    return dc
def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
    """
    Run a CREATE MODEL statement.

    The model class given in the options is instantiated with the remaining
    options, fitted on the result of the SELECT query and registered at the
    context under the model name. ``fit_kwargs`` are forwarded to the
    ``fit`` call on every training path. Nothing is returned.

    Raises a RuntimeError if the model already exists (and neither
    IF NOT EXISTS nor OR REPLACE was given), a ValueError if the
    'model_class' option is missing and an ImportError if the model class
    can not be imported.
    """
    create_model = rel.create_model()
    select = create_model.getSelectQuery()
    schema_name = create_model.getSchemaName() or context.schema_name
    model_name = create_model.getModelName()
    kwargs = convert_sql_kwargs(create_model.getSQLWithOptions())

    if model_name in context.schema[schema_name].models:
        if create_model.getIfNotExists():
            return
        elif not create_model.getOrReplace():
            raise RuntimeError(
                f"A model with the name {model_name} is already present."
            )

    logger.debug(
        f"Creating model {model_name} from query {select} with options {kwargs}"
    )

    try:
        model_class = kwargs.pop("model_class")
    except KeyError:
        # The KeyError itself carries no information for the user
        raise ValueError("Parameters must include a 'model_class' parameter.") from None

    target_column = kwargs.pop("target_column", "")
    wrap_predict = kwargs.pop("wrap_predict", None)
    wrap_fit = kwargs.pop("wrap_fit", None)
    fit_kwargs = kwargs.pop("fit_kwargs", {})

    if wrap_predict is False and "dask" not in model_class.lower():
        warnings.warn(
            f"Consider using wrap_predict=True for non-Dask model {model_class}",
            RuntimeWarning,
        )

    training_df = context.sql(select)

    # Short class names are resolved to their cuML / sklearn counterpart
    # depending on the type of the training data
    if is_cudf_type(training_df):
        model_class = gpu_classes.get(model_class, model_class)
    else:
        model_class = cpu_classes.get(model_class, model_class)

    try:
        ModelClass = import_class(model_class)
    except ImportError as err:
        raise ImportError(
            f"Failed to import model {model_class}. Make sure it is spelled correctly and the relevant packages are installed."
        ) from err

    # All options which were not consumed above configure the model itself
    model = ModelClass(**kwargs)

    # Models which are not dask-aware get wrapped by default
    is_non_dask_model = (
        "sklearn" in model_class
        or ("cuml" in model_class and "cuml.dask" not in model_class)
        or ("xgboost" in model_class and "xgboost.dask" not in model_class)
    )
    if wrap_predict is None:
        wrap_predict = is_non_dask_model
    if wrap_fit is None:
        # Incremental training is only possible with a `partial_fit` method
        wrap_fit = is_non_dask_model and hasattr(model, "partial_fit")

    if target_column:
        non_target_columns = [
            col for col in training_df.columns if col != target_column
        ]
        X = training_df[non_target_columns]
        y = training_df[target_column]
    else:
        X = training_df
        y = None

    if wrap_fit:
        from dask_sql.physical.rel.custom.wrappers import Incremental

        model = Incremental(estimator=model)

    if wrap_predict:
        from dask_sql.physical.rel.custom.wrappers import ParallelPostFit

        # When `wrap_predict` is set to True we train on single partition frames
        # because this is only useful for non dask distributed models
        # Training via delayed fit ensures that we dont have to transfer
        # data back to the client for training
        X_d = X.repartition(npartitions=1).to_delayed()
        if y is not None:
            y_d = y.repartition(npartitions=1).to_delayed()
        else:
            y_d = [None]

        # `fit_kwargs` have to reach the estimator on this path as well,
        # exactly like in the plain `model.fit` call below
        delayed_model = [
            delayed(model.fit)(x_p, y_p, **fit_kwargs)
            for x_p, y_p in zip(X_d, y_d)
        ]
        model = delayed_model[0].compute()
        if "sklearn" in model_class:
            output_meta = np.array([])
            model = ParallelPostFit(
                estimator=model,
                predict_meta=output_meta,
                predict_proba_meta=output_meta,
                transform_meta=output_meta,
            )
        else:
            model = ParallelPostFit(estimator=model)
    else:
        model.fit(X, y, **fit_kwargs)

    context.register_model(model_name, model, X.columns, schema_name=schema_name)
def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
    """
    Register a table which is read from the location given in the options.

    An already existing table is left untouched for IF NOT EXISTS,
    replaced for OR REPLACE and reported with a RuntimeError otherwise.
    A missing 'location' option raises an AttributeError.
    Nothing is returned.
    """
    statement = rel.create_table()
    schema_name = statement.getSchemaName() or context.schema_name
    table_name = statement.getTableName()

    table_exists = table_name in context.schema[schema_name].tables
    if table_exists and statement.getIfNotExists():
        return
    if table_exists and not statement.getOrReplace():
        raise RuntimeError(
            f"A table with the name {table_name} is already present."
        )

    kwargs = convert_sql_kwargs(statement.getSQLWithOptions())

    logger.debug(
        f"Creating new table with name {table_name} and parameters {kwargs}"
    )

    # Everything which is not consumed here is handed on to the reader
    file_format = kwargs.pop("format", None)
    file_format = file_format.lower() if file_format else file_format
    persist = kwargs.pop("persist", False)

    try:
        location = kwargs.pop("location")
    except KeyError:
        raise AttributeError("Parameters must include a 'location' parameter.")

    gpu = kwargs.pop("gpu", False)

    context.create_table(
        table_name,
        location,
        format=file_format,
        persist=persist,
        schema_name=schema_name,
        gpu=gpu,
        **kwargs,
    )
def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
    """
    Build a single-column table ("Params") with the parameters of the
    registered model plus the columns it was trained on.

    Raises a RuntimeError if no model with the given name is registered.
    """
    statement = rel.describe_model()

    schema_name = statement.getSchemaName() or context.schema_name
    model_name = statement.getModelName()

    registered_models = context.schema[schema_name].models
    if model_name not in registered_models:
        raise RuntimeError(f"A model with the name {model_name} is not present.")

    model, training_columns = registered_models[model_name]

    # The training columns are shown as one more "parameter" of the model
    description = model.get_params()
    description["training_columns"] = training_columns.tolist()

    df = pd.DataFrame.from_dict(description, orient="index", columns=["Params"])
    return DataContainer(
        dd.from_pandas(df, npartitions=1), ColumnContainer(df.columns)
    )
def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
    """
    Evaluate the SELECT query of the statement and shuffle its result
    on the distribution columns.
    """
    statement = rel.repartition_by()
    select = statement.getSelectQuery()
    distribute_list = statement.getDistributionColumns()

    selected = context.sql(select)
    logger.debug(f"Extracted sub-dataframe as {LoggableDataFrame(selected)}")

    logger.debug(f"Will now shuffle according to {distribute_list}")

    # The DISTRIBUTE BY itself is nothing else than a Dask shuffle
    shuffled = selected.shuffle(distribute_list)

    return DataContainer(shuffled, ColumnContainer(shuffled.columns))
class DropSchemaPlugin(BaseRelPlugin):
    """
    Remove the schema with the given name from the context.
    The SQL call looks like

        DROP SCHEMA [IF EXISTS] <schema-name>
    """

    class_name = "DropSchema"

    def convert(self, rel: "LogicalPlan", context: "dask_sql.Context"):
        statement = rel.drop_schema()
        schema_name = statement.getSchemaName()

        if schema_name in context.schema:
            context.drop_schema(schema_name)
            return

        # Dropping an unknown schema is only tolerated with IF EXISTS
        if not statement.getIfExists():
            raise RuntimeError(
                f"A SCHEMA with the name {schema_name} is not present."
            )
class ExportModelPlugin(BaseRelPlugin):
    """
    Export a trained model into a file using one of the supported model
    serialization libraries.

    Sql syntax:
        EXPORT MODEL <model-name> WITH (
            format = "pickle",
            location = "model.pkl"
        )

    1. Most of the machine learning model frameworks support pickle as a
        serialization format, for example:
            sklearn
            Pytorch
    2. To export a universal (framework agnostic) model, use the
        mlflow (https://mlflow.org/) format
        - mlflow is a framework, which supports different flavors of model
            serialization, implemented for different ML libraries like
            xgboost, catboost, lightgbm etc.
        - A mlflow model is a self-contained artifact, which contains
            everything you need for loading the model - without import errors
        - To reproduce the environment, conda.yaml files are produced while
            saving the model and stored as part of the mlflow model

    NOTE:
        - Since dask-sql expects fit-predict style models (i.e. sklearn
            compatible models), only sklearn flavoured/sklearn subclassed
            models are supported as a part of the mlflow serialization,
            i.e. only the mlflow sklearn flavour is used for all the sklearn
            compatible models.
            For example: instead of using xgb.core.Booster consider using
            xgboost.XGBClassifier since the latter is sklearn compatible.
    """

    class_name = "ExportModel"

    def convert(self, rel: "LogicalPlan", context: "dask_sql.Context"):
        """
        Serialize the registered model to the given location.

        The options "format" (default "pickle") and "location" (default
        "tmp.pkl") are consumed here, all remaining options are handed on to
        the serialization call. Nothing is returned.

        Raises a RuntimeError if the model is not registered, an ImportError
        if mlflow is requested but not installed and a NotImplementedError
        for the onnx format, for non-sklearn models with mlflow and for
        every unknown format.
        """
        export_model = rel.export_model()

        schema_name = export_model.getSchemaName() or context.schema_name
        model_name = export_model.getModelName()
        kwargs = convert_sql_kwargs(export_model.getSQLWithOptions())
        export_format = kwargs.pop("format", "pickle").lower().strip()
        location = kwargs.pop("location", "tmp.pkl").strip()
        try:
            # The training columns are stored next to the model but not exported
            model, _ = context.schema[schema_name].models[model_name]
        except KeyError as err:
            raise RuntimeError(
                f"A model with the name {model_name} is not present."
            ) from err

        logger.info(
            f"Using model serde has {export_format} and model will be exported to {location}"
        )
        if export_format in ["pickle", "pkl"]:
            with open(location, "wb") as pkl_file:
                pickle.dump(model, pkl_file, **kwargs)
        elif export_format == "joblib":
            import joblib

            joblib.dump(model, location, **kwargs)
        elif export_format == "mlflow":
            try:
                import mlflow
            except ImportError as err:  # pragma: no cover
                raise ImportError(
                    "For export in the mlflow format, you need to have mlflow installed"
                ) from err
            try:
                import sklearn
            except ImportError:  # pragma: no cover
                sklearn = None
            if sklearn is not None and isinstance(model, sklearn.base.BaseEstimator):
                mlflow.sklearn.save_model(model, location, **kwargs)
            else:
                raise NotImplementedError(
                    "dask-sql supports only sklearn compatible model i.e fit-predict style model"
                )
        elif export_format == "onnx":
            # Needs the columns and their data types for converting any model
            # format into the ONNX format, and for every framework the
            # respective ONNX converter has to be installed
            # TODO: Add support for Exporting model into ONNX format
            raise NotImplementedError("ONNX format currently not supported")
        else:
            # Do not silently "succeed" without writing anything
            raise NotImplementedError(
                f"Model export format '{export_format}' is not supported, "
                "use one of 'pickle', 'joblib' or 'mlflow'"
            )
def accuracy_score(
    y_true: ArrayLike,
    y_pred: ArrayLike,
    normalize: bool = True,
    sample_weight: Optional[ArrayLike] = None,
    compute: bool = True,
) -> ArrayLike:
    """Accuracy classification score.

    In multilabel classification, this function computes subset accuracy:
    the set of labels predicted for a sample must *exactly* match the
    corresponding set of labels in y_true.

    Parameters
    ----------
    y_true : 1d array-like, or label indicator array
        Ground truth (correct) labels.
    y_pred : 1d array-like, or label indicator array
        Predicted labels, as returned by a classifier.
    normalize : bool, optional (default=True)
        If ``False``, return the number of correctly classified samples.
        Otherwise, return the fraction of correctly classified samples.
    sample_weight : 1d array-like, optional
        Sample weights.

        .. versionadded:: 0.7.3
    compute : bool, optional (default=True)
        If ``True``, ``.compute()`` is called on the score before it is
        returned; otherwise the uncomputed score is returned.

    Returns
    -------
    score : scalar
        If ``normalize == True``, the fraction of correctly classified
        samples, else the (optionally weighted) number of correctly
        classified samples. The best performance is 1 with
        ``normalize == True`` and the number of samples with
        ``normalize == False``.
    """
    if y_true.ndim > 1:
        # A row is correct only when every label in that row matches.
        rows_match = ((y_true - y_pred) == 0).all(1)
        score = rows_match != 0
    else:
        score = y_true == y_pred

    if normalize:
        score = da.average(score, weights=sample_weight)
    elif sample_weight is not None:
        score = da.dot(score, sample_weight)
    else:
        score = score.sum()

    if compute:
        score = score.compute()
    return score


def _log_loss_inner(
    x: ArrayLike, y: ArrayLike, sample_weight: Optional[ArrayLike], **kwargs
):
    """Compute ``sklearn.metrics.log_loss`` for one block as a 1-element array."""
    # da.map_blocks wasn't able to concatenate together the results
    # when we reduce down to a scalar per block. So we make an
    # array with 1 element.
    if sample_weight is not None:
        sample_weight = sample_weight.ravel()
    return np.array(
        [sklearn.metrics.log_loss(x, y, sample_weight=sample_weight, **kwargs)]
    )


def log_loss(
    y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None, labels=None
):
    """Log loss that is evaluated block-wise for dask inputs.

    Unless both ``y_true`` and ``y_pred`` are dask collections, the call is
    delegated unchanged to ``sklearn.metrics.log_loss``. Otherwise the loss
    is computed once per block and the per-block values are combined: a
    weighted average (weights are the per-block sums of ``sample_weight``)
    when normalizing with weights, a plain mean when normalizing without
    weights, and a sum otherwise.
    """
    if not (dask.is_dask_collection(y_true) and dask.is_dask_collection(y_pred)):
        return sklearn.metrics.log_loss(
            y_true,
            y_pred,
            eps=eps,
            normalize=normalize,
            sample_weight=sample_weight,
            labels=labels,
        )

    if y_pred.ndim > 1 and y_true.ndim == 1:
        # Give the 1-d inputs a second axis so they line up with the 2-d
        # predictions; that axis is dropped again from the block results.
        y_true = y_true.reshape(-1, 1)
        drop_axis: Optional[int] = 1
        if sample_weight is not None:
            sample_weight = sample_weight.reshape(-1, 1)
    else:
        drop_axis = None

    result = da.map_blocks(
        _log_loss_inner,
        y_true,
        y_pred,
        sample_weight,
        chunks=(1,),
        drop_axis=drop_axis,
        dtype="f8",
        eps=eps,
        normalize=normalize,
        labels=labels,
    )
    if normalize and sample_weight is not None:
        sample_weight = sample_weight.ravel()
        block_weights = sample_weight.map_blocks(np.sum, chunks=(1,), keepdims=True)
        return da.average(result, 0, weights=block_weights)
    elif normalize:
        return result.mean()
    else:
        return result.sum()


def _check_sample_weight(sample_weight: Optional[ArrayLike]):
    """Raise ``ValueError`` when a ``sample_weight`` is given (unsupported)."""
    if sample_weight is not None:
        raise ValueError("'sample_weight' is not supported.")


# Mean squared error over axis 0. ``sample_weight`` must be None.
# ``multioutput`` may be None or a string; anything else raises ValueError.
# With "raw_values" the per-output errors are returned, otherwise their mean.
# ``squared=False`` applies a square root to whichever value is returned.
# ``compute`` decides whether ``.compute()`` is called on the result.
@derived_from(sklearn.metrics)
def mean_squared_error(
    y_true: ArrayLike,
    y_pred: ArrayLike,
    sample_weight: Optional[ArrayLike] = None,
    multioutput: Optional[str] = "uniform_average",
    squared: bool = True,
    compute: bool = True,
) -> ArrayLike:
    _check_sample_weight(sample_weight)
    output_errors = ((y_pred - y_true) ** 2).mean(axis=0)

    if isinstance(multioutput, str) or multioutput is None:
        if multioutput == "raw_values":
            # ``squared=False`` used to be ignored on this early return, so
            # callers asking for per-output RMSE silently received MSE.
            if not squared:
                output_errors = da.sqrt(output_errors)
            if compute:
                return output_errors.compute()
            else:
                return output_errors
    else:
        raise ValueError("Weighted 'multioutput' not supported.")
    # Aggregated path is unchanged: square root of the mean of the errors.
    result = output_errors.mean()
    if not squared:
        result = da.sqrt(result)
    if compute:
        result = result.compute()
    return result


def _check_reg_targets(
    y_true: ArrayLike, y_pred: ArrayLike, multioutput: Optional[str]
):
    """Validate ``multioutput`` and make both targets two-dimensional.

    Returns ``(None, y_true, y_pred, multioutput)`` where 1-d inputs have
    been reshaped to a single column.

    Raises
    ------
    NotImplementedError
        If ``multioutput`` is neither None nor ``"uniform_average"``.
    """
    if multioutput is not None and multioutput != "uniform_average":
        raise NotImplementedError("'multioutput' must be 'uniform_average'")

    if y_true.ndim == 1:
        y_true = y_true.reshape((-1, 1))
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape((-1, 1))

    # TODO: y_type, multioutput
    return None, y_true, y_pred, multioutput


# R^2 score averaged over the outputs. ``sample_weight`` must be None and
# ``multioutput`` must be None or "uniform_average" (see the checks below).
@derived_from(sklearn.metrics)
def r2_score(
    y_true: ArrayLike,
    y_pred: ArrayLike,
    sample_weight: Optional[ArrayLike] = None,
    multioutput: Optional[str] = "uniform_average",
    compute: bool = True,
) -> ArrayLike:
    _check_sample_weight(sample_weight)
    _, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, multioutput)
    weight = 1.0

    # Residual and total sums of squares, one value per output column.
    numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype="f8")
    denominator = (weight * (y_true - y_true.mean(axis=0)) ** 2).sum(axis=0, dtype="f8")

    nonzero_denominator = denominator != 0
    nonzero_numerator = numerator != 0
    valid_score = nonzero_denominator & nonzero_numerator

    output_chunks = getattr(y_true, "chunks", [None, None])[1]
    # Start from 1.0; outputs with a zero numerator keep that value.
    output_scores = da.ones([y_true.shape[1]], chunks=output_chunks)
    with np.errstate(all="ignore"):
        output_scores[valid_score] = 1 - (
            numerator[valid_score] / denominator[valid_score]
        )
        # Non-zero residual but zero total sum of squares scores 0.0.
        output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0

    result = output_scores.mean(axis=0)
    if compute:
        result = result.compute()
    return result
dask_sql/physical/rel/custom/predict_model.py ================================================ import logging import uuid from typing import TYPE_CHECKING import dask.dataframe as dd import pandas as pd from dask_sql.datacontainer import ColumnContainer, DataContainer from dask_sql.physical.rel.base import BaseRelPlugin if TYPE_CHECKING: import dask_sql from dask_sql._datafusion_lib import LogicalPlan logger = logging.getLogger(__name__) class PredictModelPlugin(BaseRelPlugin): """ Predict the target using the given model and dataframe from the SELECT query. The SQL call looks like SELECT FROM PREDICT (MODEL , ) The return value is the input dataframe with an additional column named "target", which contains the predicted values. The model needs to be registered at the context before using it in this function, either by calling :ref:`register_model` explicitly or by training a model using the `CREATE MODEL` SQL statement. A model can be anything which has a `predict` function. Please note however, that it will need to act on Dask dataframes. If you are using a model not optimized for this, it might be that you run out of memory if your data is larger than the RAM of a single machine. To prevent this, have a look into the dask_sql.physical.rel.custom.wrappers.ParallelPostFit meta-estimator. If you are using a model trained with `CREATE MODEL` and the `wrap_predict` flag, this is done automatically. Using this SQL is roughly equivalent to doing df = context.sql("
The result is also a table, although it is created on the fly. """ class_name = "ShowColumns" def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: show_columns = rel.show_columns() schema_name = show_columns.getSchemaName() or context.schema_name table_name = show_columns.getTableName() dc = context.schema[schema_name].tables[table_name] cols = dc.column_container.columns dtypes = list( map( lambda x: str(python_to_sql_type(x)).lower(), dc.df.dtypes, ) ) df = pd.DataFrame( { "Column": cols, "Type": dtypes, "Extra": [""] * len(cols), "Comment": [""] * len(cols), } ) cc = ColumnContainer(df.columns) dc = DataContainer(dd.from_pandas(df, npartitions=1), cc) return dc ================================================ FILE: dask_sql/physical/rel/custom/show_models.py ================================================ from typing import TYPE_CHECKING import dask.dataframe as dd import pandas as pd from dask_sql.datacontainer import ColumnContainer, DataContainer from dask_sql.physical.rel.base import BaseRelPlugin if TYPE_CHECKING: import dask_sql from dask_sql._datafusion_lib import LogicalPlan class ShowModelsPlugin(BaseRelPlugin): """ Show all MODELS currently registered/trained. The SQL is: SHOW MODELS The result is also a table, although it is created on the fly. 
class ShowSchemasPlugin(BaseRelPlugin):
    """
    List the schemas known to the context.
    The SQL is:

        SHOW SCHEMAS [FROM <catalog-name>] [LIKE '<schema-name>']

    ``information_schema`` is always part of the listing. When a LIKE value
    is given, only the schema whose name equals that value is kept.

    The result is also a table, although it is created on the fly.
    """

    class_name = "ShowSchemas"

    def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
        node = rel.show_schemas()

        # "information_schema" is a schema which is found in every presto database
        schema_names = [*context.schema.keys(), "information_schema"]
        listing = pd.DataFrame({"Schema": schema_names})

        # currently catalogs other than the default `dask_sql` are not supported
        catalog_name = node.getCatalogName() or context.catalog_name
        if catalog_name != context.catalog_name:
            raise RuntimeError(
                f"A catalog with the name {catalog_name} is not present."
            )

        # A missing LIKE value arrives here as the string "None".
        like = str(node.getLike()).strip("'")
        if like and like != "None":
            listing = listing[listing["Schema"] == like]

        columns = ColumnContainer(listing.columns)
        return DataContainer(dd.from_pandas(listing, npartitions=1), columns)
class UseSchemaPlugin(BaseRelPlugin):
    """
    Switch the schema that is used by default.

    NOTE(review): the SQL is presumably ``USE SCHEMA <schema-name>``;
    confirm against the parser.

    The named schema must already exist in ``context.schema``. On success
    both ``context.schema_name`` and the underlying ``context.context`` are
    switched to it. No table is produced.
    """

    class_name = "UseSchema"

    def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
        """
        Make the schema named in the plan node the current schema.

        Nothing is returned explicitly (the method yields ``None``) even
        though the signature is annotated with ``DataContainer``.

        Raises
        ------
        RuntimeError
            If the schema is not present in ``context.schema``.
        """
        schema_name = rel.use_schema().getSchemaName()

        if schema_name not in context.schema:
            raise RuntimeError(f"Schema {schema_name} not available")

        context.schema_name = schema_name
        # set the schema on the underlying DaskSQLContext as well
        context.context.use_schema(schema_name)
class ParallelPostFit(sklearn.base.BaseEstimator, sklearn.base.MetaEstimatorMixin):
    """Meta-estimator for parallel predict and transform.

    Parameters
    ----------
    estimator : Estimator
        The underlying estimator that is fit.

    scoring : string or callable, optional
        A single string (see :ref:`scoring_parameter`) or a callable
        (see :ref:`scoring`) to evaluate the predictions on the test set.

        For evaluating multiple metrics, either give a list of (unique)
        strings or a dict with names as keys and callables as values.

        NOTE that when using custom scorers, each scorer should return a
        single value. Metric functions returning a list/array of values
        can be wrapped into multiple scorers that return one value each.

        See :ref:`multimetric_grid_search` for an example.

        .. warning::

           If None, the estimator's default scorer (if available) is used.
           Most scikit-learn estimators will convert large Dask arrays to
           a single NumPy array, which may exhaust the memory of your worker.
           You probably want to always specify `scoring`.

    predict_meta: pd.Series, pd.DataFrame, np.array default: None (infer)
        An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches
        the output type of the estimator's ``predict`` call.
        This meta is necessary for some estimators to work with
        ``dask.dataframe`` and ``dask.array``

    predict_proba_meta: pd.Series, pd.DataFrame, np.array default: None (infer)
        An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches
        the output type of the estimator's ``predict_proba`` call.
        This meta is necessary for some estimators to work with
        ``dask.dataframe`` and ``dask.array``

    transform_meta: pd.Series, pd.DataFrame, np.array default: None (infer)
        An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches
        the output type of the estimator's ``transform`` call.
        This meta is necessary for some estimators to work with
        ``dask.dataframe`` and ``dask.array``
    """

    class_name = "ParallelPostFit"

    def __init__(
        self,
        estimator=None,
        scoring=None,
        predict_meta=None,
        predict_proba_meta=None,
        transform_meta=None,
    ):
        self.estimator = estimator
        self.scoring = scoring
        self.predict_meta = predict_meta
        self.predict_proba_meta = predict_proba_meta
        self.transform_meta = transform_meta

    def _check_array(self, X):
        """Validate an array for post-fit tasks.

        Parameters
        ----------
        X : Union[Array, DataFrame]

        Returns
        -------
        same type as 'X'

        Notes
        -----
        A two-dimensional dask array that is split along its second axis is
        rechunked so that it is blocked only along the samples. Everything
        else is returned untouched.
        """
        if not isinstance(X, da.Array):
            return X

        if X.ndim == 2 and X.numblocks[1] > 1:
            logger.debug("auto-rechunking 'X'")
            # Row chunks may only be set to "auto" when their sizes are known.
            row_sizes_known = not np.isnan(X.chunks[0]).any()
            spec = {0: "auto", 1: -1} if row_sizes_known else {1: -1}
            X = X.rechunk(spec)
        return X

    @property
    def _postfit_estimator(self):
        # The estimator instance to use for postfit tasks like score
        return self.estimator

    def _map_dataframe_partitions(self, X, part_fn, estimator, meta):
        """Apply ``part_fn`` to every partition of the dask dataframe ``X``.

        When ``meta`` is None it is obtained by running ``part_fn`` on
        ``X._meta_nonempty``. If ``map_partitions`` rejects the call with a
        ``ValueError``, it is retried without handing ``meta`` to ``part_fn``.
        """
        if meta is None:
            meta = part_fn(X._meta_nonempty, estimator)
        try:
            return X.map_partitions(part_fn, estimator, meta, meta=meta)
        except ValueError:
            if meta is None:
                # dask-dataframe relies on dd.core.no_default
                # for inferring meta
                meta = dd.core.no_default
            return X.map_partitions(part_fn, estimator=estimator, meta=meta)

    def fit(self, X, y=None, **kwargs):
        """Fit the underlying estimator.

        Parameters
        ----------
        X, y : array-like
        **kwargs
            Additional fit-kwargs for the underlying estimator.

        Returns
        -------
        self : object
        """
        logger.info("Starting fit")
        fitted = self.estimator.fit(X, y, **kwargs)

        # Expose the learned (trailing-underscore) attributes on the wrapper
        # and on the wrapped estimator.
        for target in (self, self.estimator):
            copy_learned_attributes(fitted, target)
        return self

    def partial_fit(self, X, y=None, **kwargs):
        """Call ``partial_fit`` on the underlying estimator and return self."""
        logger.info("Starting partial_fit")
        fitted = self.estimator.partial_fit(X, y, **kwargs)

        # Expose the learned (trailing-underscore) attributes on the wrapper
        # and on the wrapped estimator.
        for target in (self, self.estimator):
            copy_learned_attributes(fitted, target)
        return self

    def transform(self, X):
        """Transform block or partition-wise for dask inputs.

        For dask inputs, a dask array or dataframe is returned. For other
        inputs (NumPy array, pandas dataframe, scipy sparse matrix), the
        regular return value is returned.

        If the underlying estimator does not have a ``transform`` method, then
        an ``AttributeError`` is raised.

        Parameters
        ----------
        X : array-like

        Returns
        -------
        transformed : array-like
        """
        self._check_method("transform")
        X = self._check_array(X)
        estimator = self._postfit_estimator
        meta = self.transform_meta

        if isinstance(X, da.Array):
            if meta is None:
                meta = _get_output_dask_ar_meta_for_estimator(
                    _transform, estimator, X
                )
            return X.map_blocks(_transform, estimator=estimator, meta=meta)
        if isinstance(X, dd.DataFrame):
            return self._map_dataframe_partitions(X, _transform, estimator, meta)
        return _transform(X, estimator=estimator)

    def score(self, X, y, compute=True):
        """Returns the score on the given data.

        Parameters
        ----------
        X : array-like, shape = [n_samples, n_features]
            Input data, where n_samples is the number of samples and
            n_features is the number of features.

        y : array-like, shape = [n_samples] or [n_samples, n_output], optional
            Target relative to X for classification or regression;
            None for unsupervised learning.

        compute : bool, default True
            Passed on to the dask-aware scorer when ``X`` or ``y`` is a dask
            collection.

        Returns
        -------
        score : float
        """
        scoring = self.scoring
        X = self._check_array(X)
        y = self._check_array(y)

        if not scoring:
            # Without an explicit scoring, estimators that use the stock
            # sklearn mixin scores are mapped to the matching named scorer.
            default_score = type(self._postfit_estimator).score
            if default_score == sklearn.base.RegressorMixin.score:
                scoring = "r2"
            elif default_score == sklearn.base.ClassifierMixin.score:
                scoring = "accuracy"

        if not scoring:
            return self._postfit_estimator.score(X, y)

        if dask.is_dask_collection(X) or dask.is_dask_collection(y):
            scorer = get_scorer(scoring, compute=compute)
        else:
            scorer = sklearn.metrics.get_scorer(scoring)
        return scorer(self, X, y)

    def predict(self, X):
        """Predict for X.

        For dask inputs, a dask array or dataframe is returned. For other
        inputs (NumPy array, pandas dataframe, scipy sparse matrix), the
        regular return value is returned.

        Parameters
        ----------
        X : array-like

        Returns
        -------
        y : array-like
        """
        self._check_method("predict")
        X = self._check_array(X)
        estimator = self._postfit_estimator
        meta = self.predict_meta

        if isinstance(X, da.Array):
            if meta is None:
                meta = _get_output_dask_ar_meta_for_estimator(_predict, estimator, X)
            return X.map_blocks(
                _predict, estimator=estimator, drop_axis=1, meta=meta
            )
        if isinstance(X, dd.DataFrame):
            return self._map_dataframe_partitions(X, _predict, estimator, meta)
        return _predict(X, estimator=estimator)

    def predict_proba(self, X):
        """Probability estimates.

        For dask inputs, a dask array or dataframe is returned. For other
        inputs (NumPy array, pandas dataframe, scipy sparse matrix), the
        regular return value is returned.

        If the underlying estimator does not have a ``predict_proba``
        method, then an ``AttributeError`` is raised.

        Parameters
        ----------
        X : array or dataframe

        Returns
        -------
        y : array-like
        """
        X = self._check_array(X)
        self._check_method("predict_proba")
        estimator = self._postfit_estimator
        meta = self.predict_proba_meta

        if isinstance(X, da.Array):
            if meta is None:
                meta = _get_output_dask_ar_meta_for_estimator(
                    _predict_proba, estimator, X
                )
            # XXX: multiclass
            return X.map_blocks(
                _predict_proba,
                estimator=estimator,
                meta=meta,
                chunks=(X.chunks[0], len(estimator.classes_)),
            )
        if isinstance(X, dd.DataFrame):
            return self._map_dataframe_partitions(X, _predict_proba, estimator, meta)
        return _predict_proba(X, estimator=estimator)

    def predict_log_proba(self, X):
        """Log of probability estimates.

        For dask inputs, a dask array or dataframe is returned. For other
        inputs (NumPy array, pandas dataframe, scipy sparse matrix), the
        regular return value is returned.

        If the underlying estimator does not have a ``predict_log_proba``
        method, then an ``AttributeError`` is raised. The value itself is
        the logarithm of :meth:`predict_proba`.

        Parameters
        ----------
        X : array or dataframe

        Returns
        -------
        y : array-like
        """
        self._check_method("predict_log_proba")
        probabilities = self.predict_proba(X)
        return da.log(probabilities)

    def _check_method(self, method):
        """Check if the post-fit estimator has 'method' and return it.

        Raises
        ------
        AttributeError
        """
        estimator = self._postfit_estimator
        if hasattr(estimator, method):
            return getattr(estimator, method)

        msg = "The wrapped estimator '{}' does not have a '{}' method.".format(
            estimator, method
        )
        raise AttributeError(msg)
This wrapper provides a bridge between Dask objects and estimators implementing the ``partial_fit`` API. These *incremental learners* can train on batches of data. This fits well with Dask's blocked data structures. .. note:: This meta-estimator is not appropriate for hyperparameter optimization on larger-than-memory datasets. See the `list of incremental learners`_ in the scikit-learn documentation for a list of estimators that implement the ``partial_fit`` API. Note that `Incremental` is not limited to just these classes, it will work on any estimator implementing ``partial_fit``, including those defined outside of scikit-learn itself. Calling :meth:`Incremental.fit` with a Dask Array will pass each block of the Dask array or arrays to ``estimator.partial_fit`` *sequentially*. Like :class:`ParallelPostFit`, the methods available after fitting (e.g. :meth:`Incremental.predict`, etc.) are all parallel and delayed. The ``estimator_`` attribute is a clone of `estimator` that was actually used during the call to ``fit``. All attributes learned during training are available on ``Incremental`` directly. .. _list of incremental learners: https://scikit-learn.org/stable/modules/computing.html#incremental-learning # noqa Parameters ---------- estimator : Estimator Any object supporting the scikit-learn ``partial_fit`` API. scoring : string or callable, optional A single string (see :ref:`scoring_parameter`) or a callable (see :ref:`scoring`) to evaluate the predictions on the test set. For evaluating multiple metrics, either give a list of (unique) strings or a dict with names as keys and callables as values. NOTE that when using custom scorers, each scorer should return a single value. Metric functions returning a list/array of values can be wrapped into multiple scorers that return one value each. See :ref:`multimetric_grid_search` for an example. .. warning:: If None, the estimator's default scorer (if available) is used. 
           Most scikit-learn estimators will convert large Dask arrays to
           a single NumPy array, which may exhaust the memory of your worker.
           You probably want to always specify `scoring`.

    random_state : int or numpy.random.RandomState, optional
        Random object that determines how to shuffle blocks.

    shuffle_blocks : bool, default True
        Determines whether to call ``partial_fit`` on a randomly selected chunk
        of the Dask arrays (default), or to fit in sequential order. This does
        not control shuffle between blocks or shuffling each block.

    predict_meta: pd.Series, pd.DataFrame, np.array default: None (infer)
        An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches
        the output type of the estimator's ``predict`` call.
        This meta is necessary for some estimators to work with
        ``dask.dataframe`` and ``dask.array``

    predict_proba_meta: pd.Series, pd.DataFrame, np.array default: None (infer)
        An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches
        the output type of the estimator's ``predict_proba`` call.
        This meta is necessary for some estimators to work with
        ``dask.dataframe`` and ``dask.array``

    transform_meta: pd.Series, pd.DataFrame, np.array default: None (infer)
        An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches
        the output type of the estimator's ``transform`` call.
        This meta is necessary for some estimators to work with
        ``dask.dataframe`` and ``dask.array``

    Attributes
    ----------
    estimator_ : Estimator
        A clone of `estimator` that was actually fit during the ``.fit`` call.
""" def __init__( self, estimator=None, scoring=None, shuffle_blocks=True, random_state=None, assume_equal_chunks=True, predict_meta=None, predict_proba_meta=None, transform_meta=None, ): self.shuffle_blocks = shuffle_blocks self.random_state = random_state self.assume_equal_chunks = assume_equal_chunks super().__init__( estimator=estimator, scoring=scoring, predict_meta=predict_meta, predict_proba_meta=predict_proba_meta, transform_meta=transform_meta, ) @property def _postfit_estimator(self): check_is_fitted(self, "estimator_") return self.estimator_ def _fit_for_estimator(self, estimator, X, y, **fit_kwargs): check_scoring(estimator, self.scoring) if not dask.is_dask_collection(X) and not dask.is_dask_collection(y): try: result = estimator.partial_fit(X=X, y=y, **fit_kwargs) except ValueError: result = estimator.partial_fit( X=X, y=y, classes=np.unique(y), **fit_kwargs ) else: result = fit( estimator, X, y, random_state=self.random_state, shuffle_blocks=self.shuffle_blocks, assume_equal_chunks=self.assume_equal_chunks, **fit_kwargs, ) copy_learned_attributes(result, self) self.estimator_ = result return self def fit(self, X, y=None, **fit_kwargs): estimator = sklearn.base.clone(self.estimator) self._fit_for_estimator(estimator, X, y, **fit_kwargs) return self def partial_fit(self, X, y=None, **fit_kwargs): """Fit the underlying estimator. If this estimator has not been previously fit, this is identical to :meth:`Incremental.fit`. If it has been previously fit, ``self.estimator_`` is used as the starting point. Parameters ---------- X, y : array-like **kwargs Additional fit-kwargs for the underlying estimator. 
def handle_empty_partitions(output_meta):
    """Return a zero-row object of the same kind as ``output_meta``.

    Handles objects implementing ``__array_function__``, objects whose type
    lives in a ``scipy.sparse`` module, and objects with an ``iloc``
    accessor. ``None`` is returned for anything else.
    """
    if hasattr(output_meta, "__array_function__"):
        if len(output_meta.shape) == 1:
            shape = 0
        else:
            shape = list(output_meta.shape)
            shape[0] = 0
        ar = np.zeros(
            shape=shape,
            dtype=output_meta.dtype,
            like=output_meta,
        )
        return ar
    elif "scipy.sparse" in type(output_meta).__module__:
        # sparse matrices don't support
        # `like` due to non implemented __array_function__
        # Refer https://github.com/scipy/scipy/issues/10362
        # Note below works for both cupy and scipy sparse matrices
        if len(output_meta.shape) == 1:
            shape = 0
        else:
            shape = list(output_meta.shape)
            shape[0] = 0
        ar = type(output_meta)(shape, dtype=output_meta.dtype)
        return ar
    elif hasattr(output_meta, "iloc"):
        # ``iloc[:0]`` selects zero rows for a Series as well as a DataFrame;
        # the former ``iloc[:0, :]`` raised for one-dimensional Series metas.
        return output_meta.iloc[:0]


def _predict(part, estimator, output_meta=None):
    """Run ``estimator.predict`` on one partition.

    An empty partition is answered from ``output_meta`` (when given and of a
    supported kind) without calling the estimator.
    """
    if part.shape[0] == 0 and output_meta is not None:
        empty_output = handle_empty_partitions(output_meta)
        if empty_output is not None:
            return empty_output
    return estimator.predict(part)


def _predict_proba(part, estimator, output_meta=None):
    """Run ``estimator.predict_proba`` on one partition.

    An empty partition is answered from ``output_meta`` (when given and of a
    supported kind) without calling the estimator.
    """
    if part.shape[0] == 0 and output_meta is not None:
        empty_output = handle_empty_partitions(output_meta)
        if empty_output is not None:
            return empty_output
    return estimator.predict_proba(part)


def _transform(part, estimator, output_meta=None):
    """Run ``estimator.transform`` on one partition.

    An empty partition is answered from ``output_meta`` (when given and of a
    supported kind) without calling the estimator.
    """
    if part.shape[0] == 0 and output_meta is not None:
        empty_output = handle_empty_partitions(output_meta)
        if empty_output is not None:
            return empty_output
    return estimator.transform(part)


def _get_output_dask_ar_meta_for_estimator(model_fn, estimator, input_dask_ar):
    """
    Returns the output metadata array
    for the model function (predict, transform etc)
    by running the appropriate function on dummy data
    of shape (1, n_features)

    Parameters
    ----------
    model_fn: Model function
        _predict, _transform etc

    estimator : Estimator
        The underlying estimator that is fit.

    input_dask_ar: The input dask_array

    Returns
    -------
    metadata: metadata of output dask array
    """
    # sklearn fails if the input array has size zero.
    # It requires at least 1 sample to run successfully
    input_meta = input_dask_ar._meta
    if hasattr(input_meta, "__array_function__"):
        ar = np.zeros(
            shape=(1, input_dask_ar.shape[1]),
            dtype=input_dask_ar.dtype,
            like=input_meta,
        )
    elif "scipy.sparse" in type(input_meta).__module__:
        # sparse matrices don't support
        # `like` due to non implemented __array_function__
        # Refer https://github.com/scipy/scipy/issues/10362
        # Note below works for both cupy and scipy sparse matrices
        ar = type(input_meta)((1, input_dask_ar.shape[1]), dtype=input_dask_ar.dtype)
    else:
        func_name = model_fn.__name__.strip("_")
        msg = (
            f"Metadata for {func_name} is not provided, so Dask is "
            f"running the {func_name} "
            "function on a small dataset to guess output metadata. "
            "As a result, It is possible that Dask will guess incorrectly."
        )
        warnings.warn(msg)
        ar = np.zeros(shape=(1, input_dask_ar.shape[1]), dtype=input_dask_ar.dtype)
    return model_fn(ar, estimator)


def copy_learned_attributes(from_estimator, to_estimator):
    """Copy every instance attribute whose name ends with ``_`` to ``to_estimator``."""
    attrs = {k: v for k, v in vars(from_estimator).items() if k.endswith("_")}

    for k, v in attrs.items():
        setattr(to_estimator, k, v)


def get_scorer(scoring: Union[str, Callable], compute: bool = True) -> Callable:
    """Get a scorer from string

    Parameters
    ----------
    scoring : str | callable
        Scoring method as a string (a key of ``SCORERS``) or a metric
        callable, which is wrapped with ``make_scorer``.
    compute : bool
        Forwarded to the metric as the ``compute`` keyword.

    Returns
    -------
    scorer : callable
        The scorer.

    Raises
    ------
    ValueError
        If ``scoring`` is a string that is not a key of ``SCORERS``.
    """
    # This is the same as sklearns, only we use our SCORERS dict,
    # and don't have back-compat code
    if isinstance(scoring, str):
        try:
            scorer, registered_kwargs = SCORERS[scoring]
        except KeyError:
            raise ValueError(
                "{} is not a valid scoring value. "
                "Valid options are {}".format(scoring, sorted(SCORERS))
            )
        # Work on a copy: writing ``compute`` into the dict stored in SCORERS
        # would leak it into every later user of that registry entry.
        scorer_kwargs = dict(registered_kwargs)
    else:
        scorer = scoring
        scorer_kwargs = {}

    scorer_kwargs["compute"] = compute

    return make_scorer(scorer, **scorer_kwargs)


def check_scoring(estimator, scoring=None, **kwargs):
    """Validate ``scoring`` with sklearn, preferring the scorers in ``SCORERS``.

    sklearn's ``check_scoring`` always runs first; when ``scoring`` is a key
    of ``SCORERS`` a scorer built from that entry is returned instead of
    sklearn's result.
    """
    res = sklearn_check_scoring(estimator, scoring=scoring, **kwargs)
    if scoring in SCORERS.keys():
        func, scorer_kwargs = SCORERS[scoring]
        return make_scorer(func, **scorer_kwargs)
    return res


def fit(
    model,
    x,
    y,
    compute=True,
    shuffle_blocks=True,
    random_state=None,
    assume_equal_chunks=False,
    **kwargs,
):
    """Fit scikit learn model against dask arrays

    Model must support the ``partial_fit`` interface for online or batch
    learning.

    Ideally your rows are independent and identically distributed. By default,
    this function will step through chunks of the arrays in random order.

    Parameters
    ----------
    model: sklearn model
        Any model supporting partial_fit interface
    x: dask Array
        Two dimensional array, likely tall and skinny
    y: dask Array
        One dimensional array with same chunks as x's rows
    compute : bool
        Whether to compute this result
    shuffle_blocks : bool
        Whether to shuffle the blocks with ``random_state`` or not
    random_state : int or numpy.random.RandomState
        Random state to use when shuffling blocks
    assume_equal_chunks : bool
        Accepted for callers that pass it; not used inside this function.
    kwargs:
        options to pass to partial_fit

    Returns
    -------
    The fitted model when ``compute`` is true, otherwise a ``Delayed``.

    Raises
    ------
    ValueError
        If ``model`` has no ``partial_fit`` attribute.
    """
    nblocks, x_name = _blocks_and_name(x)
    if y is not None:
        y_nblocks, y_name = _blocks_and_name(y)
        assert y_nblocks == nblocks
    else:
        y_name = ""

    if not hasattr(model, "partial_fit"):
        msg = "The class '{}' does not implement 'partial_fit'."
        raise ValueError(msg.format(type(model)))

    order = list(range(nblocks))
    if shuffle_blocks:
        rng = sklearn.utils.check_random_state(random_state)
        rng.shuffle(order)

    name = "fit-" + dask.base.tokenize(model, x, y, kwargs, order)

    # Keys of arrays with more than one dimension carry a second index.
    if hasattr(x, "chunks") and x.ndim > 1:
        x_extra = (0,)
    else:
        x_extra = ()

    # Chain of tasks: step ``i`` takes the model produced by step ``i - 1``
    # and feeds it block ``order[i]``; ``(name, -1)`` seeds the chain.
    dsk = {(name, -1): model}
    dsk.update(
        {
            (name, i): (
                _partial_fit,
                (name, i - 1),
                (x_name, order[i]) + x_extra,
                (y_name, order[i]),
                kwargs,
            )
            for i in range(nblocks)
        }
    )

    dependencies = [x]
    if y is not None:
        dependencies.append(y)
    new_dsk = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies)
    value = Delayed((name, nblocks - 1), new_dsk, layer=name)

    if compute:
        return value.compute()
    else:
        return value


def _blocks_and_name(obj):
    """Return ``(number of blocks, graph name)`` for a dask collection.

    Objects with ``chunks`` report the number of chunks along the first
    axis; objects with ``npartitions`` report that number and use ``_name``
    when present, ``name`` otherwise.

    Raises
    ------
    TypeError
        If ``obj`` has neither ``chunks`` nor ``npartitions``.
    """
    if hasattr(obj, "chunks"):
        nblocks = len(obj.chunks[0])
        name = obj.name

    elif hasattr(obj, "npartitions"):
        # dataframe, bag
        nblocks = obj.npartitions
        if hasattr(obj, "_name"):
            # dataframe
            name = obj._name
        else:
            # bag
            name = obj.name

    else:
        # Previously this fell through to the return statement and failed
        # with an unbound-local error that did not name the real problem.
        raise TypeError(
            "Expected an object with 'chunks' or 'npartitions', got {}".format(
                type(obj)
            )
        )

    return nblocks, name


def _partial_fit(model, x, y, kwargs=None):
    """Call ``model.partial_fit(x, y, **kwargs)`` and return the model."""
    kwargs = kwargs or dict()
    model.partial_fit(x, y, **kwargs)
    return model
class AggregationSpecification:
    """
    Pairs a built-in aggregation with a fallback for dtypes the built-in
    one does not handle well.

    Most SQL aggregations map 1:1 onto a dask aggregation and can be called
    by name (e.g. AVG is the mean). When no custom aggregation is given,
    the built-in one is used as the fallback too.
    """

    def __init__(self, built_in_aggregation, custom_aggregation=None):
        self.built_in_aggregation = built_in_aggregation
        self.custom_aggregation = custom_aggregation or built_in_aggregation

    def get_supported_aggregation(self, series):
        """Return the aggregation to use for the dtype of ``series``."""
        builtin = self.built_in_aggregation
        dtype = series.dtype

        # built-in aggregations work well for numeric types
        if pd.api.types.is_numeric_dtype(dtype):
            return builtin

        # Todo: Add Categorical when support comes to dask-sql
        # Only min/max have further dtypes the built-in version supports
        if builtin not in ("min", "max"):
            return self.custom_aggregation

        if pd.api.types.is_datetime64_any_dtype(dtype):
            return builtin

        # Strings: fine for dask_cudf string columns and for pandas StringDtype
        if pd.api.types.is_string_dtype(dtype) and (
            is_cudf_type(series) or isinstance(dtype, pd.StringDtype)
        ):
            return builtin

        return self.custom_aggregation
Open TODO: So far we are following the dask default to only have a single partition after the group by (which is usual a reasonable assumption). It would be nice to control these things via HINTs. """ class_name = ["Aggregate", "Distinct"] AGGREGATION_MAPPING = { "sum": AggregationSpecification("sum", AggregationOnPandas("sum")), "$sum0": AggregationSpecification("sum", AggregationOnPandas("sum")), "any_value": AggregationSpecification( dd.Aggregation( "any_value", lambda s: s.sample(n=1).values, lambda s0: s0.sample(n=1).values, ) ), "avg": AggregationSpecification("mean", AggregationOnPandas("mean")), "stddev": AggregationSpecification("std", AggregationOnPandas("std")), "stddevsamp": AggregationSpecification("std", AggregationOnPandas("std")), "stddev_samp": AggregationSpecification("std", AggregationOnPandas("std")), "stddevpop": AggregationSpecification( dd.Aggregation( "stddevpop", lambda s: (s.count(), s.sum(), s.agg(lambda x: (x**2).sum())), lambda count, sum, sum_of_squares: ( count.sum(), sum.sum(), sum_of_squares.sum(), ), lambda count, sum, sum_of_squares: ( (sum_of_squares / count) - (sum / count) ** 2 ) ** (1 / 2), ) ), "stddev_pop": AggregationSpecification( dd.Aggregation( "stddev_pop", lambda s: (s.count(), s.sum(), s.agg(lambda x: (x**2).sum())), lambda count, sum, sum_of_squares: ( count.sum(), sum.sum(), sum_of_squares.sum(), ), lambda count, sum, sum_of_squares: ( (sum_of_squares / count) - (sum / count) ** 2 ) ** (1 / 2), ) ), "bit_and": AggregationSpecification( ReduceAggregation("bit_and", operator.and_) ), "bit_or": AggregationSpecification(ReduceAggregation("bit_or", operator.or_)), "bit_xor": AggregationSpecification(ReduceAggregation("bit_xor", operator.xor)), "count": AggregationSpecification("count"), "every": AggregationSpecification( dd.Aggregation("every", lambda s: s.all(), lambda s0: s0.all()) ), "max": AggregationSpecification("max", AggregationOnPandas("max")), "min": AggregationSpecification("min", AggregationOnPandas("min")), 
"single_value": AggregationSpecification("first"), # is null was checked earlier, now only need to compute the sum the non null values "regr_count": AggregationSpecification("sum", AggregationOnPandas("sum")), "regr_syy": AggregationSpecification( dd.Aggregation( "regr_syy", lambda s: (s.count(), s.sum(), s.agg(lambda x: (x**2).sum())), lambda count, sum, sum_of_squares: ( count.sum(), sum.sum(), sum_of_squares.sum(), ), lambda count, sum, sum_of_squares: ( sum_of_squares - (sum * (sum / count)) ), ) ), "regr_sxx": AggregationSpecification( dd.Aggregation( "regr_sxx", lambda s: (s.count(), s.sum(), s.agg(lambda x: (x**2).sum())), lambda count, sum, sum_of_squares: ( count.sum(), sum.sum(), sum_of_squares.sum(), ), lambda count, sum, sum_of_squares: ( sum_of_squares - (sum * (sum / count)) ), ) ), "variancepop": AggregationSpecification( dd.Aggregation( "variancepop", lambda s: (s.count(), s.sum(), s.agg(lambda x: (x**2).sum())), lambda count, sum, sum_of_squares: ( count.sum(), sum.sum(), sum_of_squares.sum(), ), lambda count, sum, sum_of_squares: ( (sum_of_squares / count) - (sum / count) ** 2 ), ) ), "variance_pop": AggregationSpecification( dd.Aggregation( "variance_pop", lambda s: (s.count(), s.sum(), s.agg(lambda x: (x**2).sum())), lambda count, sum, sum_of_squares: ( count.sum(), sum.sum(), sum_of_squares.sum(), ), lambda count, sum, sum_of_squares: ( (sum_of_squares / count) - (sum / count) ** 2 ), ) ), } def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: (dc,) = self.assert_inputs(rel, 1, context) agg = rel.aggregate() df = dc.df cc = dc.column_container # We make our life easier with having unique column names cc = cc.make_unique() group_exprs = agg.getGroupSets() group_columns = ( agg.getDistinctColumns() if agg.isDistinctNode() else [group_expr.column_name(rel) for group_expr in group_exprs] ) dc = DataContainer(df, cc) if not group_columns: # There was actually no GROUP BY specified in the SQL # Still, this plan can 
also be used if we need to aggregate something over the full # data sample # To reuse the code, we just create a new column at the end with a single value logger.debug("Performing full-table aggregation") # Do all aggregates df_agg, output_column_order, cc = self._do_aggregations( rel, dc, group_columns, context, ) # SQL does not care about the index, but if group columns were specified we'll want to keep those df_agg = df_agg.reset_index(drop=(not group_columns)) def try_get_backend_by_frontend_name(oc): try: return cc.get_backend_by_frontend_name(oc) except KeyError: return oc backend_output_column_order = [ try_get_backend_by_frontend_name(oc) for oc in output_column_order ] cc = ColumnContainer(df_agg.columns).limit_to(backend_output_column_order) cc = self.fix_column_to_row_type(cc, rel.getRowType()) dc = DataContainer(df_agg, cc) dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) return dc def _do_aggregations( self, rel: "LogicalPlan", dc: DataContainer, group_columns: list[str], context: "dask_sql.Context", ) -> tuple[dd.DataFrame, list[str]]: """ Main functionality: return the result dataframe and the output column order """ df = dc.df cc = dc.column_container # We might need it later. # If not, lets hope that adding a single column should not # be a huge problem... 
additional_column_name = new_temporary_column(df) df = df.assign(**{additional_column_name: 1}) # Add an entry for every grouped column, as SQL wants them first output_column_order = group_columns.copy() # Collect all aggregations we need to do ( collected_aggregations, output_column_order, df, cc, ) = self._collect_aggregations( rel, df, cc, context, additional_column_name, output_column_order ) groupby_agg_options = dask_config.get("sql.aggregate") if not collected_aggregations: backend_names = [ cc.get_backend_by_frontend_name(group_name) for group_name in group_columns ] return ( df[backend_names].drop_duplicates(**groupby_agg_options), output_column_order, cc, ) # Now we can go ahead and use these grouped aggregations # to perform the actual aggregation # It is very important to start with the non-filtered entry. # Otherwise we might loose some entries in the grouped columns df_result = None key = (None, None) if key in collected_aggregations: aggregations = collected_aggregations.pop(key) df_result = self._perform_aggregation( DataContainer(df, cc), None, None, aggregations, additional_column_name, group_columns, groupby_agg_options, ) # Now we can also the the rest for ( filter_column, distinct_column, ), aggregations in collected_aggregations.items(): agg_result = self._perform_aggregation( DataContainer(df, cc), filter_column, distinct_column, aggregations, additional_column_name, group_columns, groupby_agg_options, ) # ... 
def _collect_aggregations(
    self,
    rel: "LogicalPlan",
    df: dd.DataFrame,
    cc: ColumnContainer,
    context: "dask_sql.Context",
    additional_column_name: str,
    output_column_order: list[str],
) -> tuple[
    dict[tuple[str, str], list[tuple[str, str, Any]]],
    list[str],
    dd.DataFrame,
    ColumnContainer,
]:
    """
    Collect all aggregations together, which have the same filter column
    (and distinct column) so that the aggregations only need to be done once.

    Returns
    -------
    collected_aggregations : dict
        Mapping ``(filter_column, distinct_column)`` -> list of aggregations
        in the form (input_col, output_col, aggregation function (or string)).
        Either key element is ``None`` when not applicable.
    output_column_order : list of str
        The given list, extended in place by one output column per aggregation.
    df : dd.DataFrame
        The dataframe including any newly assigned temporary columns.
    cc : ColumnContainer
        The column container including the newly registered columns.

    Raises
    ------
    NotImplementedError
        For aggregations with more than one input (other than the ``regr_*``
        functions) and for unknown aggregation functions.
    """
    dc = DataContainer(df, cc)
    agg = rel.aggregate()

    input_rel = rel.get_inputs()[0]

    collected_aggregations = defaultdict(list)

    # convert and assign any input/filter columns that don't currently exist
    new_columns = {}
    for expr in agg.getNamedAggCalls():
        assert expr.getExprType() in {
            "Alias",
            "AggregateFunction",
            "AggregateUDF",
        }, "Do not know how to handle this case!"
        for input_expr in agg.getArgs(expr):
            input_col = input_expr.column_name(input_rel)
            if input_col not in cc._frontend_backend_mapping:
                random_name = new_temporary_column(df)
                new_columns[random_name] = RexConverter.convert(
                    input_rel, input_expr, dc, context=context
                )
                cc = cc.add(input_col, random_name)
        filter_expr = expr.getFilterExpr()
        if filter_expr is not None:
            filter_col = filter_expr.column_name(input_rel)
            if filter_col not in cc._frontend_backend_mapping:
                random_name = new_temporary_column(df)
                new_columns[random_name] = RexConverter.convert(
                    input_rel, filter_expr, dc, context=context
                )
                cc = cc.add(filter_col, random_name)
    if new_columns:
        df = df.assign(**new_columns)

    for expr in agg.getNamedAggCalls():
        schema_name = context.schema_name
        aggregation_name = agg.getAggregationFuncName(expr).lower()

        # Gather information about input columns
        inputs = agg.getArgs(expr)

        if aggregation_name == "regr_count":
            is_null = IsNullOperation()
            two_columns_proxy = new_temporary_column(df)
            col1 = cc.get_backend_by_frontend_name(inputs[0].column_name(input_rel))
            if len(inputs) == 1:
                # calcite some times gives one input/col to regr_count and
                # another col has filter column
                df = df.assign(**{two_columns_proxy: (~is_null(df[col1]))})
            else:
                col2 = cc.get_backend_by_frontend_name(
                    inputs[1].column_name(input_rel)
                )
                # both cols should be not null
                df = df.assign(
                    **{
                        two_columns_proxy: (
                            ~is_null(df[col1]) & (~is_null(df[col2]))
                        )
                    }
                )
            input_col = two_columns_proxy
        elif aggregation_name == "regr_syy":
            input_col = inputs[0].column_name(input_rel)
        elif aggregation_name == "regr_sxx":
            input_col = inputs[1].column_name(input_rel)
        elif len(inputs) == 1:
            input_col = inputs[0].column_name(input_rel)
        elif len(inputs) == 0:
            input_col = additional_column_name
        else:
            raise NotImplementedError("Can not cope with more than one input")

        filter_expr = expr.getFilterExpr()
        if filter_expr is not None:
            filter_backend_col = cc.get_backend_by_frontend_name(
                filter_expr.column_name(input_rel)
            )
        else:
            filter_backend_col = None

        try:
            # This unifies CPU and GPU behavior by ensuring that performing a
            # sum on a null column results in null and not 0
            if aggregation_name == "sum" and isinstance(df._meta, pd.DataFrame):
                aggregation_function = AggregationSpecification(
                    dd.Aggregation(
                        name="custom_sum",
                        chunk=lambda s: s.sum(min_count=1),
                        agg=lambda s0: s0.sum(min_count=1),
                    )
                )
            else:
                aggregation_function = self.AGGREGATION_MAPPING[aggregation_name]
        except KeyError:
            try:
                aggregation_function = context.schema[schema_name].functions[
                    aggregation_name
                ]
            except KeyError:  # pragma: no cover
                raise NotImplementedError(
                    f"Aggregation function {aggregation_name} not implemented (yet)."
                )

        # Resolve the backend column for *every* aggregation. It used to be
        # set only for AggregationSpecification instances, so a DISTINCT
        # user-defined aggregation either hit an unbound name or silently
        # reused the column of the previous aggregation.
        backend_name = cc.get_backend_by_frontend_name(input_col)
        if isinstance(aggregation_function, AggregationSpecification):
            aggregation_function = aggregation_function.get_supported_aggregation(
                df[backend_name]
            )

        # Finally, extract the output column name
        output_col = expr.toString()

        # Store the aggregation
        distinct_backend_col = backend_name if expr.isDistinctAgg() else None
        collected_aggregations[(filter_backend_col, distinct_backend_col)].append(
            (input_col, output_col, aggregation_function)
        )
        output_column_order.append(output_col)

    return collected_aggregations, output_column_order, df, cc
class DaskCrossJoinPlugin(BaseRelPlugin):
    """
    While similar to `DaskJoinPlugin` a `CrossJoin` has enough of a differing
    structure to justify its own plugin. This in turn limits the number of
    Dask tasks that are generated for `CrossJoin`'s when compared to a
    standard `Join`
    """

    class_name = "CrossJoin"

    def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
        """
        Build the cross product of both inputs by merging them on a
        temporary constant key column and name the result columns after
        the row type of ``rel``.
        """
        # We now have two inputs (from left and right), so we fetch them both
        dc_lhs, dc_rhs = self.assert_inputs(rel, 2, context)

        # Create a 'key' column in both DataFrames to join on.
        # ``assign`` returns new dataframes; item assignment would modify the
        # input dataframes in place and leak the temporary column to whoever
        # else holds a reference to them.
        cross_join_key = utils.new_temporary_column(dc_lhs.df)
        df_lhs = dc_lhs.df.assign(**{cross_join_key: 1})
        df_rhs = dc_rhs.df.assign(**{cross_join_key: 1})

        result = df_lhs.merge(df_rhs, on=cross_join_key, suffixes=("", "0")).drop(
            columns=cross_join_key
        )
        cc = ColumnContainer(result.columns)

        # Rename columns like the rel specifies
        row_type = rel.getRowType()
        field_specifications = [str(f) for f in row_type.getFieldNames()]

        cc = cc.rename(dict(zip(cc.columns, field_specifications)))
        cc = self.fix_column_to_row_type(cc, row_type)

        return DataContainer(result, cc)
""" class_name = "EmptyRelation" def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: col_names = ( rel.empty_relation().emptyColumnNames() if len(rel.empty_relation().emptyColumnNames()) > 0 else ["_empty"] ) data = None if len(rel.empty_relation().emptyColumnNames()) > 0 else [0] return DataContainer( dd.from_pandas(pd.DataFrame(data, columns=col_names), npartitions=1), ColumnContainer(col_names), ) ================================================ FILE: dask_sql/physical/rel/logical/explain.py ================================================ from typing import TYPE_CHECKING from dask_sql.physical.rel.base import BaseRelPlugin if TYPE_CHECKING: import dask_sql from dask_sql._datafusion_lib import LogicalPlan class ExplainPlugin(BaseRelPlugin): """ Explain is used to explain the query with the EXPLAIN keyword """ class_name = "Explain" def convert(self, rel: "LogicalPlan", context: "dask_sql.Context"): explain_strings = rel.explain().getExplainString() return "\n".join(explain_strings) ================================================ FILE: dask_sql/physical/rel/logical/filter.py ================================================ import logging from typing import TYPE_CHECKING, List, Union import dask.config as dask_config import dask.dataframe as dd import numpy as np from dask_sql.datacontainer import DataContainer from dask_sql.physical.rel.base import BaseRelPlugin from dask_sql.physical.rex import RexConverter from dask_sql.physical.utils.filter import attempt_predicate_pushdown if TYPE_CHECKING: import dask_sql from dask_sql._datafusion_lib import LogicalPlan logger = logging.getLogger(__name__) def filter_or_scalar( df: dd.DataFrame, filter_condition: Union[np.bool_, dd.Series], add_filters: List = None, ): """ Some (complex) SQL queries can lead to a strange condition which is always true or false. We do not need to filter in this case. See https://github.com/dask-contrib/dask-sql/issues/87. 
""" if np.isscalar(filter_condition): if not filter_condition: # pragma: no cover # empty dataset logger.warning("Join condition is always false - returning empty dataset") return df.head(0, compute=False) else: return df # In SQL, a NULL in a boolean is False on filtering filter_condition = filter_condition.fillna(False) out = df[filter_condition] # dask-expr should implicitly handle predicate pushdown if dask_config.get("sql.predicate_pushdown") and not dd._dask_expr_enabled(): return attempt_predicate_pushdown(out, add_filters=add_filters) else: return out class DaskFilterPlugin(BaseRelPlugin): """ DaskFilter is used on WHERE clauses. We just evaluate the filter (which is of type RexNode) and apply it """ class_name = "Filter" def convert( self, rel: "LogicalPlan", context: "dask_sql.Context", ) -> DataContainer: (dc,) = self.assert_inputs(rel, 1, context) df = dc.df cc = dc.column_container filter = rel.filter() # Every logic is handled in the RexConverter # we just need to apply it here condition = filter.getCondition() df_condition = RexConverter.convert(rel, condition, dc, context=context) df = filter_or_scalar(df, df_condition) cc = self.fix_column_to_row_type(cc, rel.getRowType()) return DataContainer(df, cc) ================================================ FILE: dask_sql/physical/rel/logical/join.py ================================================ import logging import operator import warnings from functools import reduce from typing import TYPE_CHECKING import dask.dataframe as dd from dask import config as dask_config from dask_sql.datacontainer import ColumnContainer, DataContainer from dask_sql.physical.rel.base import BaseRelPlugin from dask_sql.physical.rel.logical.filter import filter_or_scalar from dask_sql.physical.rex import RexConverter from dask_sql.utils import is_cudf_type if TYPE_CHECKING: import dask_sql from dask_sql._datafusion_lib import Expression, LogicalPlan logger = logging.getLogger(__name__) class DaskJoinPlugin(BaseRelPlugin): 
""" A DaskJoin is used when (surprise) joining two tables. SQL allows for quite complicated joins with difficult conditions. dask/pandas only knows about equijoins on a specific column. We use a trick, which is also used in e.g. blazingSQL: we split the join condition into two parts: * everything which is an equijoin * the rest The first part is then used for the dask merging, whereas the second part is just applied as a filter afterwards. This will make joining more time-consuming that is needs to be but so far, it is the only solution... """ class_name = "Join" JOIN_TYPE_MAPPING = { "INNER": "inner", "LEFT": "left", "RIGHT": "right", "FULL": "outer", "LEFTSEMI": "leftsemi", "LEFTANTI": "leftanti", } def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: # Joining is a bit more complicated, so lets do it in steps: join = rel.join() # 1. We now have two inputs (from left and right), so we fetch them both dc_lhs, dc_rhs = self.assert_inputs(rel, 2, context) cc_lhs = dc_lhs.column_container cc_rhs = dc_rhs.column_container # 2. dask's merge will do some smart things with columns, which have the same name # on lhs an rhs (which also includes reordering). # However, that will confuse our column numbering in SQL. # So we make our life easier by converting the column names into unique names # We will convert back in the end cc_lhs_renamed = cc_lhs.make_unique("lhs") cc_rhs_renamed = cc_rhs.make_unique("rhs") dc_lhs_renamed = DataContainer(dc_lhs.df, cc_lhs_renamed) dc_rhs_renamed = DataContainer(dc_rhs.df, cc_rhs_renamed) df_lhs_renamed = dc_lhs_renamed.assign() df_rhs_renamed = dc_rhs_renamed.assign() join_type = join.getJoinType() join_type = self.JOIN_TYPE_MAPPING[str(join_type)] # TODO: update with correct implementation of leftsemi for CPU # https://github.com/dask-contrib/dask-sql/issues/1190 if join_type == "leftsemi" and not is_cudf_type(df_lhs_renamed): join_type = "inner" # 3. 
The join condition can have two forms, that we can understand # (a) a = b # (b) X AND Y AND a = b AND Z ... (can also be multiple a = b) # The first case is very simple and we do not need any additional filter # In the second case we do a merge on all the a = b, # and then apply a filter using the other expressions. # In all other cases, we need to do a full table cross join and filter afterwards. # As this is probably non-sense for large tables, but there is no other # known solution so far. join_condition = join.getCondition() lhs_on, rhs_on, filter_condition = None, None, None # A user can write certain queries that really should be `cross join` queries # that will still enter this portion of the logic. IF the join_condition is # None that means there are no conditions to join on. This means a cross join. # By not entering this body during that condition we ensure that later on in # processing we perform a cross join. if join_condition is not None: lhs_on, rhs_on, filter_condition = self._split_join_condition( join_condition ) # lhs_on and rhs_on are the indices of the columns to merge on. # The given column indices are for the full, merged table which consists # of lhs and rhs put side-by-side (in this order) # We therefore need to normalize the rhs indices relative to the rhs table. rhs_on = [index - len(df_lhs_renamed.columns) for index in rhs_on] # 4. dask can only merge on the same column names. # We therefore create new columns on purpose, which have a distinct name. assert len(lhs_on) == len(rhs_on) if lhs_on: # 5. Now we can finally merge on these columns # The resulting dataframe will contain all (renamed) columns from the lhs and rhs # plus the added columns df = self._join_on_columns( df_lhs_renamed, df_rhs_renamed, lhs_on, rhs_on, join_type, ) else: # 5. We are in the complex join case # where we have no column to merge on # This means we have no other chance than to merge # everything with everything... 
# TODO: we should implement a shortcut # for filter conditions that are always false df = dd.merge( df_lhs_renamed.assign(common=1), df_rhs_renamed.assign(common=1), on="common", ).drop(columns="common") warnings.warn( "Need to do a cross-join, which is typically very resource heavy", ResourceWarning, ) # 6. So the next step is to make sure # we have the correct column order (and to remove the temporary join columns) if join_type in ("leftsemi", "leftanti"): correct_column_order = list(df_lhs_renamed.columns) else: correct_column_order = list(df_lhs_renamed.columns) + list( df_rhs_renamed.columns ) cc = ColumnContainer(df.columns).limit_to(correct_column_order) # and to rename them like the rel specifies row_type = rel.getRowType() field_specifications = [str(f) for f in row_type.getFieldNames()] if join_type in ("leftsemi", "leftanti"): field_specifications = field_specifications[: len(cc.columns)] cc = cc.rename( { from_col: to_col for from_col, to_col in zip(cc.columns, field_specifications) } ) cc = self.fix_column_to_row_type(cc, row_type, join_type) dc = DataContainer(df, cc) # 7. 
Last but not least we apply any filters by and-chaining together the filters if filter_condition: # This line is a bit of code duplication with RexCallPlugin - but I guess it is worth to keep it separate filter_condition = reduce( operator.and_, [ RexConverter.convert(rel, rex, dc, context=context) for rex in filter_condition ], ) logger.debug(f"Additionally applying filter {filter_condition}") df = filter_or_scalar(df, filter_condition) dc = DataContainer(df, cc) dc = self.fix_dtype_to_row_type(dc, rel.getRowType(), join_type) # # Rename underlying DataFrame column names back to their original values before returning # df = dc.assign() # dc = DataContainer(df, ColumnContainer(cc.columns)) return dc def _join_on_columns( self, df_lhs_renamed: dd.DataFrame, df_rhs_renamed: dd.DataFrame, lhs_on: list[str], rhs_on: list[str], join_type: str, ) -> dd.DataFrame: # SQL compatibility: when joining on columns that # contain NULLs, pandas will actually happily # keep those NULLs. That is however not compatible with # SQL, so we get rid of them here if join_type in ["inner", "right"]: df_lhs_filter = reduce( operator.and_, [~df_lhs_renamed.iloc[:, index].isna() for index in lhs_on], ) df_lhs_renamed = df_lhs_renamed[df_lhs_filter] if join_type in ["inner", "left", "leftanti", "leftsemi"]: df_rhs_filter = reduce( operator.and_, [~df_rhs_renamed.iloc[:, index].isna() for index in rhs_on], ) df_rhs_renamed = df_rhs_renamed[df_rhs_filter] lhs_columns_to_add = { f"common_{i}": df_lhs_renamed["lhs_" + str(index)] for i, index in enumerate(lhs_on) } rhs_columns_to_add = { f"common_{i}": df_rhs_renamed.iloc[:, index] for i, index in enumerate(rhs_on) } df_lhs_with_tmp = df_lhs_renamed.assign(**lhs_columns_to_add) df_rhs_with_tmp = df_rhs_renamed.assign(**rhs_columns_to_add) added_columns = list(lhs_columns_to_add.keys()) broadcast = dask_config.get("sql.join.broadcast") if join_type == "leftanti" and not is_cudf_type(df_lhs_with_tmp): df = df_lhs_with_tmp.merge( df_rhs_with_tmp, 
on=added_columns, how="left", broadcast=broadcast, indicator=True, ).drop(columns=added_columns) df = df[df["_merge"] == "left_only"].drop( columns=["_merge"] + list(df_rhs_with_tmp.columns), errors="ignore" ) else: df = df_lhs_with_tmp.merge( df_rhs_with_tmp, on=added_columns, how=join_type, broadcast=broadcast, ).drop(columns=added_columns) return df def _split_join_condition( self, join_condition: "Expression" ) -> tuple[list[str], list[str], list["Expression"]]: if str(join_condition.getRexType()) in ["RexType.Literal", "RexType.Reference"]: return [], [], [join_condition] elif not str(join_condition.getRexType()) == "RexType.Call": raise NotImplementedError("Can not understand join condition.") lhs_on = [] rhs_on = [] filter_condition = [] try: lhs_on, rhs_on, filter_condition_part = self._extract_lhs_rhs( join_condition ) filter_condition.extend(filter_condition_part) except AssertionError: filter_condition.append(join_condition) if lhs_on and rhs_on: return lhs_on, rhs_on, filter_condition return [], [], [join_condition] def _extract_lhs_rhs(self, rex): assert str(rex.getRexType()) == "RexType.Call" operator_name = str(rex.getOperatorName()) assert operator_name in ["=", "AND"] operands = rex.getOperands() assert len(operands) == 2 if operator_name == "=": operand_lhs = operands[0] operand_rhs = operands[1] if ( str(operand_lhs.getRexType()) == "RexType.Reference" and str(operand_rhs.getRexType()) == "RexType.Reference" ): lhs_index = operand_lhs.getIndex() rhs_index = operand_rhs.getIndex() # The rhs table always comes after the lhs # table. Therefore we have a very simple # way of checking, which index comes from which # input if lhs_index > rhs_index: lhs_index, rhs_index = rhs_index, lhs_index return [lhs_index], [rhs_index], [] raise AssertionError( "Invalid join condition" ) # pragma: no cover. Do not how how it could be triggered. 
else: lhs_indices = [] rhs_indices = [] filter_conditions = [] for operand in operands: try: lhs_index, rhs_index, filter_condition = self._extract_lhs_rhs( operand ) filter_conditions.extend(filter_condition) lhs_indices.extend(lhs_index) rhs_indices.extend(rhs_index) except AssertionError: filter_conditions.append(operand) return lhs_indices, rhs_indices, filter_conditions ================================================ FILE: dask_sql/physical/rel/logical/limit.py ================================================ from typing import TYPE_CHECKING import dask.dataframe as dd from dask import config as dask_config from dask.blockwise import Blockwise from dask.highlevelgraph import MaterializedLayer from dask.layers import DataFrameIOLayer from dask_sql.datacontainer import DataContainer from dask_sql.physical.rel.base import BaseRelPlugin from dask_sql.physical.rex import RexConverter if TYPE_CHECKING: import dask_sql from dask_sql._datafusion_lib import LogicalPlan class DaskLimitPlugin(BaseRelPlugin): """ Limit is used to only get a certain part of the dataframe (LIMIT). """ class_name = "Limit" def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: (dc,) = self.assert_inputs(rel, 1, context) df = dc.df cc = dc.column_container # Retrieve the RexType::Literal values from the `LogicalPlan` Limit # Fetch -> LIMIT # Skip -> OFFSET limit = RexConverter.convert(rel, rel.limit().getFetch(), df, context=context) offset = RexConverter.convert(rel, rel.limit().getSkip(), df, context=context) # apply offset to limit if specified if limit and offset: limit += offset # apply limit and/or offset to DataFrame df = self._apply_limit(df, limit, offset) cc = self.fix_column_to_row_type(cc, rel.getRowType()) # No column type has changed, so no need to cast again return DataContainer(df, cc) def _apply_limit(self, df: dd.DataFrame, limit: int, offset: int) -> dd.DataFrame: """ Limit the dataframe to the window [offset, limit]. 
Unfortunately, Dask does not currently support row selection through `iloc`, so this must be done using a custom partition function. However, it is sometimes possible to compute this window using `head` when an `offset` is not specified. """ # if no offset is specified we can use `head` to compute the window if not offset: # if `check-first-partition` enabled, check if we have a relatively simple Dask graph and if so, # check if the first partition contains our desired window if ( dask_config.get("sql.limit.check-first-partition") and not dd._dask_expr_enabled() and all( [ isinstance( layer, (DataFrameIOLayer, Blockwise, MaterializedLayer) ) for layer in df.dask.layers.values() ] ) and limit <= len(df.partitions[0]) ): return df.head(limit, compute=False) return df.head(limit, npartitions=-1, compute=False) # compute the size of each partition # TODO: compute `cumsum` here when dask#9067 is resolved partition_borders = df.map_partitions(lambda x: len(x)) def limit_partition_func(df, partition_borders, partition_info=None): """Limit the partition to values contained within the specified window, returning an empty dataframe if there are none""" # with dask-expr we may need to explicitly compute here if hasattr(partition_borders, "compute"): partition_borders = partition_borders.compute() # TODO: remove the `cumsum` call here when dask#9067 is resolved partition_borders = partition_borders.cumsum().to_dict() partition_index = ( partition_info["number"] if partition_info is not None else 0 ) partition_border_left = ( partition_borders[partition_index - 1] if partition_index > 0 else 0 ) partition_border_right = partition_borders[partition_index] if (limit and limit < partition_border_left) or ( offset >= partition_border_right ): return df.iloc[0:0] from_index = max(offset - partition_border_left, 0) to_index = ( min(limit, partition_border_right) if limit else partition_border_right ) - partition_border_left return df.iloc[from_index:to_index] return 
class DaskProjectPlugin(BaseRelPlugin):
    """
    Physical plugin for ``Projection`` plan nodes.

    A projection
    (a) evaluates expressions into (new) columns and
    (b) restricts the result to the projected columns, in projection order.
    """

    class_name = "Projection"

    def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
        """Apply the projection described by ``rel`` to its single input."""
        (dc,) = self.assert_inputs(rel, 1, context)
        df, cc = dc.df, dc.column_container

        ordered_names = []  # frontend names in projection order
        computed_columns = {}  # backend name -> freshly computed expression
        frontend_to_backend = {}  # frontend name -> backend column name

        for raw_name, expr in rel.projection().getNamedProjects():
            name = str(raw_name)
            ordered_names.append(name)

            if expr.getRexType() == RexType.Reference:
                # A plain column reference: point the new frontend name at the
                # already existing backend column instead of assigning it again
                backend_name = cc.get_backend_by_frontend_index(expr.getIndex())
                logger.debug(
                    f"Not re-adding the same column {name} (but just referencing it)"
                )
            else:
                backend_name = new_temporary_column(df)
                computed_columns[backend_name] = RexConverter.convert(
                    rel, expr, dc, context=context
                )
                logger.debug(f"Adding a new column {name} out of {expr}")

            frontend_to_backend[name] = backend_name

        # Materialize all computed columns in a single assign call
        if computed_columns:
            df = df.assign(**computed_columns)

        # Register the frontend names ...
        for name, backend_name in frontend_to_backend.items():
            cc = cc.add(name, backend_name)

        # ... and restrict to the projected columns in the requested order
        cc = cc.limit_to(ordered_names)
        cc = self.fix_column_to_row_type(cc, rel.getRowType())

        result = DataContainer(df, cc)
        return self.fix_dtype_to_row_type(result, rel.getRowType())
""" class_name = "com.dask.sql.nodes.DaskSample" def convert( self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" ) -> DataContainer: (dc,) = self.assert_inputs(rel, 1, context) df = dc.df cc = dc.column_container parameters = rel.getSamplingParameters() is_bernoulli = parameters.isBernoulli() fraction = float(parameters.getSamplingPercentage()) seed = parameters.getRepeatableSeed() if parameters.isRepeatable() else None if is_bernoulli: df = df.sample(frac=fraction, replace=False, random_state=seed) else: random_state = np.random.RandomState(seed) random_choice = random_state.choice( [True, False], size=df.npartitions, replace=True, p=[fraction, 1 - fraction], ) if random_choice.any(): df = df.partitions[random_choice] else: df = df.head(0, compute=False) return DataContainer(df, cc) ================================================ FILE: dask_sql/physical/rel/logical/sort.py ================================================ from typing import TYPE_CHECKING from dask_sql.datacontainer import DataContainer from dask_sql.physical.rel.base import BaseRelPlugin from dask_sql.physical.utils.sort import apply_sort if TYPE_CHECKING: import dask_sql from dask_sql._datafusion_lib import LogicalPlan class DaskSortPlugin(BaseRelPlugin): """ DaskSort is used to sort by columns (ORDER BY). 
""" class_name = "Sort" def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: (dc,) = self.assert_inputs(rel, 1, context) df = dc.df cc = dc.column_container sort_plan = rel.sort() sort_expressions = sort_plan.getCollation() sort_columns = [ cc.get_backend_by_frontend_name(expr.column_name(rel)) for expr in sort_expressions ] sort_ascending = [expr.isSortAscending() for expr in sort_expressions] sort_null_first = [expr.isSortNullsFirst() for expr in sort_expressions] sort_num_rows = sort_plan.getNumRows() df = apply_sort( df, sort_columns, sort_ascending, sort_null_first, sort_num_rows ) cc = self.fix_column_to_row_type(cc, rel.getRowType()) # No column type has changed, so no need to cast again return DataContainer(df, cc) ================================================ FILE: dask_sql/physical/rel/logical/subquery_alias.py ================================================ from typing import TYPE_CHECKING from dask_sql.datacontainer import DataContainer from dask_sql.physical.rel.base import BaseRelPlugin if TYPE_CHECKING: import dask_sql from dask_sql._datafusion_lib import LogicalPlan class SubqueryAlias(BaseRelPlugin): """ SubqueryAlias is used to assign an alias to a table and/or subquery """ class_name = "SubqueryAlias" def convert(self, rel: "LogicalPlan", context: "dask_sql.Context"): (dc,) = self.assert_inputs(rel, 1, context) cc = dc.column_container alias = rel.subquery_alias().getAlias() return DataContainer( dc.df, cc.rename( { col: renamed_col for col, renamed_col in zip( cc.columns, (f"{alias}.{col.split('.')[-1]}" for col in cc.columns), ) } ), ) ================================================ FILE: dask_sql/physical/rel/logical/table_scan.py ================================================ import logging import operator from functools import reduce from typing import TYPE_CHECKING from dask.dataframe import _dask_expr_enabled from dask.utils_test import hlg_layer from dask_sql.datacontainer import DataContainer from 
class DaskTableScanPlugin(BaseRelPlugin):
    """
    A DaskTableScan is the main ingredient:
    it will get the data from the database.
    It is always used, when the SQL looks like

        SELECT .... FROM table ....

    We need to get the dask dataframe from the registered
    tables and return the requested columns from it.
    """

    class_name = "TableScan"

    def convert(
        self,
        rel: "LogicalPlan",
        context: "dask_sql.Context",
    ) -> DataContainer:
        """
        Look up the scanned table in ``context`` and apply the filters and
        projections attached to the table scan node.
        """
        # There should not be any input. This is the first step.
        self.assert_inputs(rel, 0)

        # Rust table_scan instance handle
        table_scan = rel.table_scan()

        # The table(s) we need to return
        dask_table = rel.getTable()
        # Schema and table names are lower-cased before the lookup in the context
        schema_name, table_name = (n.lower() for n in context.fqn(dask_table))

        dc = context.schema[schema_name].tables[table_name]

        # Apply filter before projections since filter columns may not be in projections
        dc = self._apply_filters(table_scan, rel, dc, context)
        dc = self._apply_projections(table_scan, dask_table, dc)

        cc = dc.column_container
        cc = self.fix_column_to_row_type(cc, rel.getRowType())
        dc = DataContainer(dc.df, cc)
        dc = self.fix_dtype_to_row_type(dc, rel.getRowType())
        return dc

    def _apply_projections(self, table_scan, dask_table, dc):
        """Restrict ``dc`` to the columns requested by the table scan."""
        # If the 'TableScan' instance contains projected columns only retrieve those columns
        # otherwise get all projected columns from the 'Projection' instance, which is contained
        # in the 'RelDataType' instance, aka 'row_type'
        df = dc.df
        cc = dc.column_container
        if table_scan.containsProjections():
            # Map the projected frontend names to backend column names
            field_specifications = list(
                map(cc.get_backend_by_frontend_name, table_scan.getTableScanProjects())
            )  # Assumes these are column projections only and field names match table column names
            df = df[field_specifications]
        else:
            field_specifications = [
                str(f) for f in dask_table.getRowType().getFieldNames()
            ]

        cc = cc.limit_to(field_specifications)
        return DataContainer(df, cc)

    def _apply_filters(self, table_scan, rel, dc, context):
        """
        And-chain the filters attached to the table scan and apply them to ``dc``.

        If conjunctive DNF filters are present they are used (together with the
        expressions reported as ``io_unfilterable_exprs``) and their second tuple
        element is forwarded to ``filter_or_scalar`` as ``add_filters``.
        Otherwise all plain filters of the table scan are applied.
        """
        df = dc.df
        cc = dc.column_container
        all_filters = table_scan.getFilters()
        conjunctive_dnf_filters = table_scan.getDNFFilters().filtered_exprs
        non_dnf_filters = table_scan.getDNFFilters().io_unfilterable_exprs

        if conjunctive_dnf_filters:
            # Extract the PyExprs from the conjunctive DNF filters
            filter_exprs = [f[0] for f in conjunctive_dnf_filters]
            if non_dnf_filters:
                filter_exprs.extend(non_dnf_filters)

            df_condition = reduce(
                operator.and_,
                [
                    RexConverter.convert(rel, rex, dc, context=context)
                    for rex in filter_exprs
                ],
            )
            df = filter_or_scalar(
                df, df_condition, add_filters=[f[1] for f in conjunctive_dnf_filters]
            )
        elif all_filters:
            df_condition = reduce(
                operator.and_,
                [
                    RexConverter.convert(rel, rex, dc, context=context)
                    for rex in all_filters
                ],
            )
            df = filter_or_scalar(df, df_condition)

        # Debug aid for legacy (non dask-expr) graphs only: log the creation info
        # of the "read-parquet" layer, if the graph has one
        if not _dask_expr_enabled():
            try:
                logger.debug(hlg_layer(df.dask, "read-parquet").creation_info)
            except KeyError:
                pass

        return DataContainer(df, cc)
class DaskValuesPlugin(BaseRelPlugin):
    """
    A DaskValue is a table just consisting of
    raw values (nothing database-dependent).
    For example

        SELECT 1 + 1;

    We generate a pandas dataframe and a dask
    dataframe out of it directly here.
    We assume that this will only ever be used for small
    data samples.
    """

    class_name = "com.dask.sql.nodes.DaskValues"

    def convert(
        self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context"
    ) -> DataContainer:
        """Build a single-partition dataframe out of the literal rows of ``rel``."""
        # There should not be any input. This is the first step.
        self.assert_inputs(rel, 0)

        rows = []
        for rex_expression_row in rel.getTuples():
            # We convert each of the cells in the row
            # using a RexConverter.
            # As we do not have any information on the
            # column headers, we just name them with
            # their index.
            # There is no input data container to reference, hence dc=None.
            rows.append(
                {
                    str(i): RexConverter.convert(rel, rex_cell, None, context=context)
                    for i, rex_cell in enumerate(rex_expression_row)
                }
            )

        # TODO: we explicitly reference pandas and dask here -> might be worth making this more general
        # We assume here that when using the values plan, the resulting dataframe will be quite small
        if rows:
            df = pd.DataFrame(rows)
        else:
            # No rows at all: still create the columns named by the row type
            field_names = [str(x) for x in rel.getRowType().getFieldNames()]
            df = pd.DataFrame(columns=field_names)

        df = dd.from_pandas(df, npartitions=1)
        cc = ColumnContainer(df.columns)

        cc = self.fix_column_to_row_type(cc, rel.getRowType())
        dc = DataContainer(df, cc)
        dc = self.fix_dtype_to_row_type(dc, rel.getRowType())
        return dc
class Indexer(BaseIndexer):
    """
    Window description used for complex windows with arbitrary start and end.
    This class is directly taken from the fugue project.

    ``start`` and ``end`` are row offsets relative to the current row
    (negative: preceding, positive: following, 0: the current row itself).
    ``None`` means unbounded in that direction.
    """

    def __init__(self, start: Optional[int], end: Optional[int]):
        # BaseIndexer stores additional keyword arguments as attributes.
        # The instance itself must not be passed on: the first positional
        # argument of BaseIndexer.__init__ is ``index_array``.
        super().__init__(start=start, end=end)

    def _get_window_bounds(
        self,
        num_values: int = 0,
        min_periods: Optional[int] = None,
        center: Optional[bool] = None,
        closed: Optional[str] = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Compute the half-open window ``[start[i], end[i])`` for each of the
        ``num_values`` rows, clipped to the range ``[0, num_values]``.
        """
        if self.start is None:
            # Unbounded preceding: every window starts at the first row
            start = np.zeros(num_values, dtype=np.int64)
        else:
            start = np.arange(self.start, self.start + num_values, dtype=np.int64)
            if self.start < 0:
                # The first rows would start before the beginning of the data
                start[: -self.start] = 0
            elif self.start > 0:
                # The last rows would start after the end of the data
                start[-self.start :] = num_values

        if self.end is None:
            # Unbounded following: every window ends after the last row
            end = np.full(num_values, num_values, dtype=np.int64)
        else:
            # +1 because the end of the window is exclusive
            end = np.arange(self.end + 1, self.end + 1 + num_values, dtype=np.int64)
            if self.end > 0:
                end[-self.end :] = num_values
            elif self.end < 0:
                end[: -self.end] = 0
            # self.end == 0: the window ends at the current row, which is
            # always within the data, so there is nothing to clip
        return start, end

    def get_window_bounds(
        self,
        num_values: int = 0,
        min_periods: Optional[int] = None,
        center: Optional[bool] = None,
        closed: Optional[str] = None,
        step: Optional[int] = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Entry point called by pandas; ``step`` is accepted but not used."""
        return self._get_window_bounds(num_values, min_periods, center, closed)
class DaskWindowPlugin(BaseRelPlugin):
    """
    A DaskWindow is an expression, which calculates a given function over the dataframe
    while first optionally partitioning the data and optionally sorting it.

    Expressions like `F OVER (PARTITION BY x ORDER BY y)` apply f on each
    partition separately and sort by y before applying f. The result of this
    calculation has however the same length as the input dataframe - it is not an aggregation.

    Typical examples include ROW_NUMBER and lagging.
    """

    class_name = "Window"

    OPERATION_MAPPING = {
        # row_number is the easiest one: we do not even need to have any windowing.
        # We therefore treat it separately
        "row_number": None,
        "$sum0": SumOperation(),
        "sum": SumOperation(),
        "count": CountOperation(),
        "max": MaxOperation(),
        "min": MinOperation(),
        "single_value": FirstValueOperation(),
        "first_value": FirstValueOperation(),
        "last_value": LastValueOperation(),
        "avg": AvgOperation(),
    }

    def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
        """Apply every window group of ``rel`` to its single input, one after the other."""
        (dc,) = self.assert_inputs(rel, 1, context)

        # Output to the right field names right away
        field_names = rel.getRowType().getFieldNames()
        for window in rel.window().getGroups():
            dc = self._apply_window(rel, window, dc, field_names, context)

        # Finally, fix the output schema if needed
        df = dc.df
        cc = dc.column_container
        cc = self.fix_column_to_row_type(cc, rel.getRowType())
        dc = DataContainer(df, cc)
        dc = self.fix_dtype_to_row_type(dc, rel.getRowType())

        return dc

    def _apply_window(
        self,
        rel,
        window,
        dc: DataContainer,
        field_names: list[str],
        context: "dask_sql.Context",
    ):
        """
        Evaluate a single window expression on ``dc``.

        The dataframe is grouped by the partition columns, each group is sorted
        and windowed in ``map_on_each_group`` and the resulting columns are
        registered under the next free names from ``field_names``.
        """
        df = dc.df
        cc = dc.column_container

        # Now extract the groupby and order information
        sort_columns, sort_ascending, sort_null_first = self._extract_ordering(
            rel, window, cc
        )
        logger.debug(
            f"Before applying the function, sorting according to {sort_columns}."
        )

        df, group_columns, temporary_columns = self._extract_groupby(
            df, rel, window, dc, context
        )
        logger.debug(
            f"Before applying the function, partitioning according to {group_columns}."
        )

        operations, df = self._extract_operations(rel, window, df, dc, context)
        # The operand columns are only needed during the calculation
        for _, _, cols in operations:
            temporary_columns += cols

        newly_created_columns = [new_column for _, new_column, _ in operations]

        logger.debug(f"Will create {newly_created_columns} new columns")

        # Default window bounds when no frame is specified:
        # unbounded preceding up to the current row if there is an order by,
        # unbounded preceding up to unbounded following if there is none
        window_frame = rel.window().getWindowFrame(window)
        if not window_frame:
            lower_bound = BoundDescription(
                is_unbounded=True,
                is_preceding=True,
                is_following=False,
                is_current_row=False,
                offset=None,
            )
            upper_bound = (
                BoundDescription(
                    is_unbounded=False,
                    is_preceding=False,
                    is_following=False,
                    is_current_row=True,
                    offset=None,
                )
                if sort_columns
                else BoundDescription(
                    is_unbounded=True,
                    is_preceding=False,
                    is_following=True,
                    is_current_row=False,
                    offset=None,
                )
            )
        else:
            lower_bound = to_bound_description(window_frame.getLowerBound())
            upper_bound = to_bound_description(window_frame.getUpperBound())

        # Apply the windowing operation
        filled_map = partial(
            map_on_each_group,
            sort_columns=sort_columns,
            sort_ascending=sort_ascending,
            sort_null_first=sort_null_first,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            operations=operations,
        )

        # TODO: That is a bit of a hack. We should really use the real column dtype
        meta = df._meta.assign(**{col: 0.0 for col in newly_created_columns})

        df = df.groupby(group_columns, dropna=False)[df.columns.tolist()].apply(
            filled_map, meta=meta
        )
        logger.debug(
            f"Having created a dataframe {LoggableDataFrame(df)} after windowing. Will now drop {temporary_columns}."
        )
        df = df.drop(columns=temporary_columns).reset_index(drop=True)

        dc = DataContainer(df, cc)
        df = dc.df
        cc = dc.column_container

        # The new columns take the field names directly following the ones
        # which are already registered in the column container
        for c in newly_created_columns:
            field_name = field_names[len(cc.columns)]
            cc = cc.add(field_name, c)
        dc = DataContainer(df, cc)
        logger.debug(
            f"Removed unneeded columns and registered new ones: {LoggableDataFrame(dc)}."
        )
        return dc

    def _extract_groupby(
        self,
        df: dd.DataFrame,
        rel,
        window,
        dc: DataContainer,
        context: "dask_sql.Context",
    ) -> tuple[dd.DataFrame, list[str], list[str]]:
        """
        Prepare grouping columns we can later use while applying the main function.

        Returns the (possibly extended) dataframe, the backend names of the
        grouping columns and the temporary columns to drop afterwards.
        """
        partition_keys = rel.window().getPartitionExprs(window)
        if partition_keys:
            group_columns = [
                dc.column_container.get_backend_by_frontend_name(o.column_name(rel))
                for o in partition_keys
            ]
            temporary_columns = []
        else:
            # No PARTITION BY: group by a constant column, so that all rows
            # end up in one single group
            temp_col = new_temporary_column(df)
            df = df.assign(**{temp_col: 1})
            group_columns = [temp_col]
            temporary_columns = [temp_col]

        return df, group_columns, temporary_columns

    def _extract_ordering(
        self, rel, window, cc: ColumnContainer
    ) -> tuple[list[str], list[bool], list[bool]]:
        """
        Prepare sorting information we can later use while applying the main function.

        Returns the backend sort columns together with one ascending flag and
        one nulls-first flag per sort column.
        """
        sort_expressions = rel.window().getSortExprs(window)
        sort_columns = [
            cc.get_backend_by_frontend_name(expr.column_name(rel))
            for expr in sort_expressions
        ]
        sort_ascending = [expr.isSortAscending() for expr in sort_expressions]
        sort_null_first = [expr.isSortNullsFirst() for expr in sort_expressions]
        return sort_columns, sort_ascending, sort_null_first

    def _extract_operations(
        self,
        rel,
        window,
        df: dd.DataFrame,
        dc: DataContainer,
        context: "dask_sql.Context",
    ) -> tuple[list[tuple[Optional[Callable], str, list[str]]], dd.DataFrame]:
        """
        Resolve the window function and materialize its operands.

        Returns a list of ``(operation, new column name, operand column names)``
        tuples and the dataframe extended by the temporary operand columns.
        ``operation`` is ``None`` for ``row_number``.

        Raises ``NotImplementedError`` if the function is neither built in nor
        registered in the current schema of ``context``.
        """
        # Finally apply the actual function on each group separately
        operations = []

        # TODO: datafusion returns only window func expression per window
        # This can be optimized in the physical plan to collect all aggs for a given window
        operator_name = rel.window().getWindowFuncName(window).lower()

        try:
            operation = self.OPERATION_MAPPING[operator_name]
        except KeyError:  # pragma: no cover
            # Fall back to the functions registered by the user
            try:
                operation = context.schema[context.schema_name].functions[operator_name]
            except KeyError:  # pragma: no cover
                raise NotImplementedError(f"{operator_name} not (yet) implemented")

        logger.debug(f"Executing {operator_name} on {str(LoggableDataFrame(df))}")

        # TODO: can be optimized by re-using already present columns
        temporary_operand_columns = {
            new_temporary_column(df): RexConverter.convert(rel, o, dc, context=context)
            for o in rel.window().getArgs(window)
        }
        df = df.assign(**temporary_operand_columns)
        temporary_operand_columns = list(temporary_operand_columns.keys())

        operations.append(
            (operation, new_temporary_column(df), temporary_operand_columns)
        )

        return operations, df
class RexConverter(Pluggable):
    """
    Helper to convert from rex to a python expression

    This class stores plugins which can convert from rex expressions to
    python expressions (single values or dask dataframe columns).
    A plugin is registered under its class attribute "class_name";
    the module-level mapping _REX_TYPE_TO_PLUGIN translates the rex type
    of an expression into that name.
    Plugins are expected to have a convert (instance) method in the form

        def convert(self, rel, rex, dc, context)

    to do the actual conversion.
    """

    @classmethod
    def add_plugin_class(cls, plugin_class: type[BaseRexPlugin], replace=True):
        """
        Convenience function to add a class directly to the plugins.

        An instance of `plugin_class` is created and stored under
        `plugin_class.class_name`; `replace` is handed on to `add_plugin`.
        """
        logger.debug(f"Registering REX plugin for {plugin_class.class_name}")
        cls.add_plugin(plugin_class.class_name, plugin_class(), replace=replace)

    @classmethod
    def convert(
        cls,
        rel: "LogicalPlan",
        rex: "Expression",
        dc: DataContainer,
        context: "dask_sql.Context",
    ) -> Union[dd.DataFrame, Any]:
        """
        Convert the given Expression into a python expression
        using the stored plugins.

        Raises a NotImplementedError if the rex type of `rex` is not
        listed in _REX_TYPE_TO_PLUGIN or if no plugin is registered
        for the mapped name.
        """
        rex_type = str(rex.getRexType())

        # An unknown rex type is reported the same way as a missing plugin,
        # instead of leaking the KeyError of the dictionary lookup
        expr_type = _REX_TYPE_TO_PLUGIN.get(rex_type)
        if expr_type is None:
            raise NotImplementedError(
                f"No conversion for class {rex_type} available (yet)."
            )

        try:
            plugin_instance = cls.get_plugin(expr_type)
        except KeyError:  # pragma: no cover
            raise NotImplementedError(
                f"No conversion for class {expr_type} available (yet)."
            )

        logger.debug(
            f"Processing REX {rex} using {plugin_instance.__class__.__name__}..."
        )

        df = plugin_instance.convert(rel, rex, dc, context=context)
        logger.debug(f"Processed REX {rex} into {LoggableDataFrame(df)}")
        return df
def as_timelike(op):
    """
    Turn `op` into something that can take part in datetime arithmetic.

    * a ``np.int64`` is read as a number of days and becomes a ``np.timedelta64``
    * a string becomes a ``np.datetime64``; if numpy cannot parse it directly,
      it is parsed with the format ``%Y-%m-%d`` and normalized first
    * datetime64 typed input and ``np.timedelta64`` values are returned unchanged

    Everything else raises a ValueError.
    """
    if isinstance(op, np.int64):
        return np.timedelta64(op, "D")

    if isinstance(op, str):
        try:
            return np.datetime64(op)
        except ValueError:
            # Second chance for dates numpy rejects, e.g. without zero padding
            parsed = datetime.strptime(op, "%Y-%m-%d")
            return np.datetime64(parsed.strftime("%Y-%m-%d"))

    if pd.api.types.is_datetime64_dtype(op) or isinstance(op, np.timedelta64):
        return op

    raise ValueError(f"Don't know how to make {type(op)} timelike")
class ReduceOperation(Operation):
    """
    Special operator, which is executed by reducing an operation over the input.

    With more than one operand, `operation` is folded over the operands from
    left to right. With a single operand, `unary_operation` is called instead
    (it defaults to `operation`).
    """

    def __init__(self, operation: Callable, unary_operation: Callable = None):
        self.operation = operation
        self.unary_operation = unary_operation or operation

        # Take over every requirement flag of the wrapped operation
        # (the same four flags Operation.of propagates), so that the caller
        # passes exactly the keyword arguments the wrapped operation asks for
        self.needs_dc = Operation.op_needs_dc(self.operation)
        self.needs_rex = Operation.op_needs_rex(self.operation)
        self.needs_context = Operation.op_needs_context(self.operation)
        self.needs_rel = Operation.op_needs_rel(self.operation)

        super().__init__(self.reduce)

    def reduce(self, *operands, **kwargs):
        """
        Apply the stored operation to `operands`.

        If at least one operand is a frame of datetime64 dtype, all operands
        are converted with `as_timelike` before the reduction.
        Keyword arguments are handed on to every call of the operation.
        """
        if len(operands) > 1:
            if any(
                is_frame(op) and pd.api.types.is_datetime64_dtype(op)
                for op in operands
            ):
                operands = tuple(map(as_timelike, operands))
            return reduce(partial(self.operation, **kwargs), operands)
        else:
            return self.unary_operation(*operands, **kwargs)
class CaseOperation(Operation):
    """The case operator (basically an if then else)"""

    def __init__(self):
        super().__init__(self.case)

    def case(self, *operands) -> SeriesOrScalar:
        """
        Evaluate ``CASE WHEN condition THEN value ... [ELSE fallback]``.

        The operands come as (condition, value) pairs, optionally followed
        by a single fallback. More than three operands are handled
        recursively: everything after the first pair forms the fallback.
        Without a fallback, None is used.
        """
        assert operands

        condition, if_true = operands[0], operands[1]

        operand_count = len(operands)
        if operand_count > 3:
            otherwise = self.case(*operands[2:])
        elif operand_count == 2:
            # CASE/WHEN statement without an ELSE
            otherwise = None
        else:
            otherwise = operands[2]

        if is_frame(if_true):
            return if_true.where(condition, other=otherwise)

        if is_frame(otherwise):
            # Same selection seen from the fallback: keep it where the condition fails
            return otherwise.where(~condition, other=if_true)

        if not is_frame(condition):
            # Only scalars are involved
            return if_true if condition else otherwise

        # Only the condition is a frame. To still be able to use "where",
        # build a frame shaped like the condition that holds the scalar
        # value everywhere and select on that.
        placeholder = condition.apply(
            lambda x: if_true, meta=(condition.name, type(if_true))
        )
        return placeholder.where(condition, other=otherwise)
class IsNullOperation(Operation):
    """The is null operator"""

    def __init__(self):
        super().__init__(self.null)

    def null(
        self,
        df: SeriesOrScalar,
    ) -> SeriesOrScalar:
        """
        Returns true where `df` is null (where `df` can also be just a scalar).
        """
        if is_frame(df):
            return df.isna()

        # pd.isna already treats None, NaN, NaT and pd.NA as missing.
        # An additional np.isnan check must not be used here: it raises a
        # TypeError for non-numeric scalars such as strings or timestamps.
        return pd.isna(df)
class SubStringOperation(Operation):
    """The substring operator (get a slice of a string)"""

    def __init__(self):
        super().__init__(self.substring)

    def substring(self, s, start, length=None):
        """
        Return the part of `s` beginning at `start` with at most `length` characters.

        Attention: SQL starts counting at 1. A `start` of zero or less is
        treated like the beginning of the string.
        Without `length`, everything up to the end of the string is returned;
        a `length` of 0 gives an empty string.
        `s` can be a frame or a plain string.
        """
        if start <= 0:
            start = 0
        else:
            start -= 1

        # Test against None explicitly: a length of 0 is a valid request
        # and must not be mistaken for "no length given"
        end = start + length if length is not None else None

        if is_frame(s):
            return s.str.slice(start, end)

        return s[start:end]
class TimeStampAddOperation(Operation):
    """Add an interval, given as a count of a named unit, to a datetime input"""

    def __init__(self):
        super().__init__(self.timestampadd)

    def timestampadd(self, unit, interval, df: SeriesOrScalar):
        """
        Return `df` shifted forward by `interval` times `unit`.

        `unit` is matched case-insensitively, `interval` is converted to int.
        Integer typed input is cast to datetime64[s], everything else to
        datetime64[ns], before the offset is added.

        Raises a RuntimeError for negative intervals and a
        NotImplementedError for units which are not known.
        """
        unit_name = unit.upper()
        amount = int(interval)
        if amount < 0:
            raise RuntimeError(f"Negative time interval {amount} is not supported.")

        target_dtype = (
            "datetime64[s]" if pd.api.types.is_integer_dtype(df) else "datetime64[ns]"
        )
        df = df.astype(target_dtype)

        if is_cudf_type(df):
            from cudf import DateOffset
        else:
            from pandas.tseries.offsets import DateOffset

        # (accepted unit names, DateOffset keyword, factor applied to the interval)
        unit_table = (
            ({"YEAR", "YEARS"}, "years", 1),
            # A quarter is expressed as three months
            ({"QUARTER", "QUARTERS"}, "months", 3),
            ({"MONTH", "MONTHS"}, "months", 1),
            ({"WEEK", "WEEKS", "SQL_TSI_WEEK"}, "weeks", 1),
            ({"DAY", "DAYS", "SQL_TSI_DAY"}, "days", 1),
            ({"HOUR", "HOURS", "SQL_TSI_HOUR"}, "hours", 1),
            ({"MINUTE", "MINUTES", "SQL_TSI_MINUTE"}, "minutes", 1),
            ({"SECOND", "SECONDS", "SQL_TSI_SECOND"}, "seconds", 1),
            ({"MILLISECOND", "MILLISECONDS"}, "milliseconds", 1),
            ({"MICROSECOND", "MICROSECONDS"}, "microseconds", 1),
        )

        for accepted_names, keyword, factor in unit_table:
            if unit_name in accepted_names:
                return df + DateOffset(**{keyword: amount * factor})

        raise NotImplementedError(
            f"Timestamp addition with {unit_name} is not supported."
        )
class CeilFloorOperation(PredicateBasedOperation):
    """
    Apply ceil/floor operations on a series depending on its dtype (datetime like vs normal)
    """

    def __init__(self, round_method: str):
        """
        Init with the rounding direction, either "ceil" or "floor".

        Input for which `is_datetime` holds is rounded by `_round_datetime`,
        everything else by the function of the same name in dask.array.
        """
        assert round_method in {
            "ceil",
            "floor",
        }, "Round method can only be either ceil or floor"
        super().__init__(
            is_datetime,  # if the series is dt type
            self._round_datetime,
            getattr(da, round_method),
        )
        self.round_method = round_method

    def _round_datetime(self, *operands):
        """
        Round the datetime input to the given unit.

        Expects exactly two operands: the data and the name of the unit
        (matched case-insensitively). Raises a NotImplementedError for
        units without a known frequency.
        """
        df, unit = operands
        df = convert_to_datetime(df)

        # SQL unit name -> frequency string understood by the rounding method.
        # Microseconds use "us", as the older alias "U" is deprecated in pandas.
        unit_map = {
            "DAY": "D",
            "HOUR": "h",
            "MINUTE": "min",
            "SECOND": "s",
            "MICROSECOND": "us",
            "MILLISECOND": "ms",
        }

        # Only the lookup is guarded, so that a KeyError coming out of the
        # rounding itself is not reported as an unsupported unit
        try:
            freq = unit_map[unit.upper()]
        except KeyError:
            raise NotImplementedError(
                f"{self.round_method} TO {unit} is not (yet) implemented."
            ) from None

        return getattr(df, self.round_method)(freq)
class RandOperation(BaseRandomOperation):
    """Create a random number between 0 and 1"""

    def __init__(self):
        super().__init__(f=self.rand)

    def rand(self, seed: int = None, dc: DataContainer = None):
        """Return one random float per row of the frame stored in `dc`"""
        random_series = self.random_frame(seed=seed, dc=dc)
        return random_series

    def random_function(self, partition, random_state, kwargs):
        """Draw as many samples from `random_state` as `partition` has rows"""
        row_count = len(partition)
        return random_state.random_sample(size=row_count)
class ExtractOperation(Operation):
    """
    Extract a single date/time field (in the style of PostgreSQL's
    date_part) from a datetime-like input.
    """

    def __init__(self):
        super().__init__(self.date_part)

    def date_part(self, what, df: SeriesOrScalar):
        """
        Return the field named by `what` of `df`.

        `what` is matched case-insensitively and `df` is passed through
        `convert_to_datetime` first.
        Raises a NotImplementedError for unknown field names.
        """
        what = what.upper()
        df = convert_to_datetime(df)

        if what in {"YEAR", "YEARS"}:
            return df.year
        elif what in {"CENTURY", "CENTURIES"}:
            return da.trunc(df.year / 100)
        elif what in {"DAY", "DAYS"}:
            return df.day
        elif what in {"DECADE", "DECADES"}:
            return da.trunc(df.year / 10)
        elif what == "DOW":
            # Shift by one, so that a dayofweek of 6 becomes 0
            return (df.dayofweek + 1) % 7
        elif what == "DOY":
            return df.dayofyear
        elif what in {"HOUR", "HOURS"}:
            return df.hour
        elif what in {"MICROSECOND", "MICROSECONDS"}:
            return df.microsecond
        elif what in {"MILLENIUM", "MILLENIUMS", "MILLENNIUM", "MILLENNIUMS"}:
            return da.trunc(df.year / 1000)
        elif what in {"MILLISECOND", "MILLISECONDS"}:
            # NOTE(review): this multiplies the microsecond field by 1000 -
            # confirm that this is the intended unit
            return da.trunc(1000 * df.microsecond)
        elif what in {"MINUTE", "MINUTES"}:
            return df.minute
        elif what in {"MONTH", "MONTHS"}:
            return df.month
        elif what in {"QUARTER", "QUARTERS"}:
            return df.quarter
        elif what in {"SECOND", "SECONDS"}:
            return df.second
        elif what in {"WEEK", "WEEKS"}:
            return df.isocalendar().week
        elif what == "DATE":
            return (
                df.date()
                if isinstance(df, pd.Timestamp)
                else dd.to_datetime(df.strftime("%Y-%m-%d"))
            )
        else:
            raise NotImplementedError(f"Extraction of {what} is not (yet) implemented.")
""" class_name = "RexCall" OPERATION_MAPPING = { # "binary" functions "between": BetweenOperation(), "and": ReduceOperation(operation=operator.and_), "or": ReduceOperation(operation=operator.or_), ">": ReduceOperation(operation=operator.gt), ">=": ReduceOperation(operation=operator.ge), "<": ReduceOperation(operation=operator.lt), "<=": ReduceOperation(operation=operator.le), "=": ReduceOperation(operation=operator.eq), "!=": ReduceOperation(operation=operator.ne), "<>": ReduceOperation(operation=operator.ne), "+": ReduceOperation(operation=operator.add, unary_operation=lambda x: x), "-": ReduceOperation(operation=operator.sub, unary_operation=lambda x: -x), "/": ReduceOperation(operation=SQLDivisionOperator()), "*": ReduceOperation(operation=operator.mul), "is distinct from": NotOperation().of(IsNotDistinctOperation()), "is not distinct from": IsNotDistinctOperation(), "/int": IntDivisionOperator(), # special operations "cast": CastOperation(), "case": CaseOperation(), "not like": NotOperation().of(LikeOperation(case_sensitive=True)), "like": LikeOperation(case_sensitive=True), "not ilike": NotOperation().of(LikeOperation(case_sensitive=False)), "ilike": LikeOperation(case_sensitive=False), "not similar to": NotOperation().of(SimilarOperation()), "similar to": SimilarOperation(), "negative": NegativeOperation(), "not": NotOperation(), "in list": InListOperation(), "in subquery": InSubqueryOperation(), "is null": IsNullOperation(), "is not null": NotOperation().of(IsNullOperation()), "is true": IsTrueOperation(), "is not true": NotOperation().of(IsTrueOperation()), "is false": IsFalseOperation(), "is not false": NotOperation().of(IsFalseOperation()), "is unknown": IsNullOperation(), "is not unknown": NotOperation().of(IsNullOperation()), "rand": RandOperation(), "random": RandOperation(), "rand_integer": RandIntegerOperation(), "search": SearchOperation(), # Unary math functions "abs": TensorScalarOperation(lambda x: x.abs(), np.abs), "acos": Operation(da.arccos), 
"asin": Operation(da.arcsin), "atan": Operation(da.arctan), "atan2": Operation(da.arctan2), "cbrt": Operation(da.cbrt), "ceil": CeilFloorOperation("ceil"), "cos": Operation(da.cos), "cot": Operation(lambda x: 1 / da.tan(x)), "degrees": Operation(da.degrees), "exp": Operation(da.exp), "floor": CeilFloorOperation("floor"), "log10": Operation(da.log10), "ln": Operation(da.log), "mod": Operation(da.mod), "power": Operation(da.power), "radians": Operation(da.radians), "round": TensorScalarOperation(lambda x, *ops: x.round(*ops), np.round), "sign": Operation(da.sign), "sin": Operation(da.sin), "tan": Operation(da.tan), "truncate": Operation(da.trunc), # string operations "||": ReduceOperation(operation=operator.add), "concat": ReduceOperation(operation=operator.add), "characterlength": TensorScalarOperation( lambda x: x.str.len(), lambda x: len(x) ), "character_length": TensorScalarOperation( lambda x: x.str.len(), lambda x: len(x) ), "upper": TensorScalarOperation(lambda x: x.str.upper(), lambda x: x.upper()), "lower": TensorScalarOperation(lambda x: x.str.lower(), lambda x: x.lower()), "position": PositionOperation(), "trim": TrimOperation(), "ltrim": TrimOperation("LEADING"), "rtrim": TrimOperation("TRAILING"), "btrim": TrimOperation("BOTH"), "overlay": OverlayOperation(), "substr": SubStringOperation(), "substring": SubStringOperation(), "initcap": TensorScalarOperation(lambda x: x.str.title(), lambda x: x.title()), "coalesce": CoalesceOperation(), "replace": ReplaceOperation(), # date/time operations "extract_date": ExtractOperation(), "localtime": Operation(lambda *args: pd.Timestamp.now()), "localtimestamp": Operation(lambda *args: pd.Timestamp.now()), "current_time": Operation(lambda *args: pd.Timestamp.now()), "current_date": Operation(lambda *args: pd.Timestamp.now()), "current_timestamp": Operation(lambda *args: pd.Timestamp.now()), "last_day": TensorScalarOperation( lambda x: x + pd.tseries.offsets.MonthEnd(1), lambda x: convert_to_datetime(x) + 
def convert(
    self,
    rel: "LogicalPlan",
    expr: "Expression",
    dc: DataContainer,
    context: "dask_sql.Context",
) -> SeriesOrScalar:
    """
    Evaluate a call expression.

    The operands of ``expr`` are converted first, then the operation is
    looked up by its lower-cased operator name (built-in mapping first,
    functions registered on the current schema second) and invoked with
    the converted operands.

    Raises ``NotImplementedError`` when the operator name is found in
    neither place.
    """
    # Turn every operand RexNode into a python expression
    operands = []
    for operand_rex in expr.getOperands():
        operands.append(RexConverter.convert(rel, operand_rex, dc, context=context))

    # FIXME: cuDF doesn't support binops between decimal columns and numpy ints / floats
    if dask_config.get("sql.mappings.decimal_support") == "cudf":
        touches_decimal = any(
            str(getattr(operand, "dtype", None)) == "decimal128"
            for operand in operands
        )
        if touches_decimal:
            from decimal import Decimal

            def _as_decimal_friendly(value):
                # floats become Decimal, integer scalars become python ints,
                # everything else is passed through untouched
                if isinstance(value, float):
                    return Decimal(str(value))
                if np.isscalar(value) and pd.api.types.is_integer_dtype(value):
                    return value.item()
                return value

            operands = [_as_decimal_friendly(operand) for operand in operands]

    # Resolve the operation belonging to the operator name
    schema_name = context.schema_name
    operator_name = expr.getOperatorName().lower()

    try:
        operation = self.OPERATION_MAPPING[operator_name]
    except KeyError:
        try:
            operation = context.schema[schema_name].functions[operator_name]
        except KeyError:  # pragma: no cover
            raise NotImplementedError(
                f"RexCall operator '{operator_name}' not (yet) implemented"
            )

    logger.debug(
        f"Executing {operator_name} on {[str(LoggableDataFrame(df)) for df in operands]}"
    )

    # Only hand over the extra arguments the operation asks for
    optional_arguments = (
        (Operation.op_needs_dc, "dc", dc),
        (Operation.op_needs_rex, "rex", expr),
        (Operation.op_needs_context, "context", context),
        (Operation.op_needs_rel, "rel", rel),
    )
    kwargs = {
        keyword: value
        for is_needed, keyword, value in optional_arguments
        if is_needed(operation)
    }

    return operation(*operands, **kwargs)
    # TODO: We have information on the typing here - we should use it
class RexInputRefPlugin(BaseRexPlugin):
    """
    Plugin for RexInputRef expressions, i.e. expressions that point at
    exactly one column of the input.
    They usually show up inside expressions computing a function
    on a column of a table.
    """

    class_name = "InputRef"

    def convert(
        self,
        rel: "LogicalPlan",
        rex: "Expression",
        dc: DataContainer,
        context: "dask_sql.Context",
    ) -> dd.Series:
        # The column is referenced by its (frontend) index, which the
        # column container translates into the backend column name
        column_index = rex.getIndex()
        backend_name = dc.column_container.get_backend_by_frontend_index(
            column_index
        )
        return dc.df[backend_name]
def filter_on(self, series: dd.Series):
    """
    Build the condition selecting the entries of ``series`` inside this range.

    An endpoint of ``None`` means "unbounded" on that side and contributes
    a plain ``True``; an open endpoint uses a strict comparison, a closed
    one an inclusive comparison. Both sides are combined with ``&``.
    """
    lower, upper = self.lower_endpoint, self.upper_endpoint

    if lower is None:
        above_lower = True
    elif self.lower_open:
        above_lower = lower < series
    else:
        above_lower = lower <= series

    if upper is None:
        below_upper = True
    elif self.upper_open:
        below_upper = upper > series
    else:
        below_upper = upper >= series

    return above_lower & below_upper
""" class_name = "RexLiteral" def convert( self, rel: "LogicalPlan", rex: "Expression", dc: DataContainer, context: "dask_sql.Context", ) -> Any: literal_type = str(rex.getType()) # Call the Rust function to get the actual value and convert the Rust # type name back to a SQL type if literal_type == "Boolean": try: literal_type = SqlTypeName.BOOLEAN literal_value = rex.getBoolValue() except TypeError: literal_type = SqlTypeName.NULL literal_value = None elif literal_type == "Float32": literal_type = SqlTypeName.FLOAT literal_value = rex.getFloat32Value() elif literal_type == "Float64": literal_type = SqlTypeName.DOUBLE literal_value = rex.getFloat64Value() elif literal_type == "Decimal128": literal_type = SqlTypeName.DECIMAL value, _, scale = rex.getDecimal128Value() literal_value = value / (10**scale) elif literal_type == "UInt8": literal_type = SqlTypeName.TINYINT literal_value = rex.getUInt8Value() elif literal_type == "UInt16": literal_type = SqlTypeName.SMALLINT literal_value = rex.getUInt16Value() elif literal_type == "UInt32": literal_type = SqlTypeName.INTEGER literal_value = rex.getUInt32Value() elif literal_type == "UInt64": literal_type = SqlTypeName.BIGINT literal_value = rex.getUInt64Value() elif literal_type == "Int8": literal_type = SqlTypeName.TINYINT literal_value = rex.getInt8Value() elif literal_type == "Int16": literal_type = SqlTypeName.SMALLINT literal_value = rex.getInt16Value() elif literal_type == "Int32": literal_type = SqlTypeName.INTEGER literal_value = rex.getInt32Value() elif literal_type == "Int64": literal_type = SqlTypeName.BIGINT literal_value = rex.getInt64Value() elif literal_type == "Utf8": literal_type = SqlTypeName.VARCHAR literal_value = rex.getStringValue() elif literal_type == "Date32": literal_type = SqlTypeName.DATE literal_value = np.datetime64(rex.getDate32Value(), "D") elif literal_type == "Date64": literal_type = SqlTypeName.DATE literal_value = np.datetime64(rex.getDate64Value(), "ms") elif literal_type == "Time64": 
class RexScalarSubqueryPlugin(BaseRexPlugin):
    """
    Plugin for expressions that reference a subquery.

    There is hardly any logic in here; it exists as a plugin (instead of
    inline code) to stay in line with the other rex converters.
    """

    class_name = "ScalarSubquery"

    def convert(
        self,
        rel: "LogicalPlan",
        rex: "Expression",
        dc: DataContainer,
        context: "dask_sql.Context",
    ) -> dd.DataFrame:
        # Pull the LogicalPlan of the subquery out of the expression,
        # convert it and hand back the resulting dataframe
        subquery_plan = rex.getSubqueryLogicalPlan()
        return RelConverter.convert(subquery_plan, context=context).df
preserve_filters Whether to preserve pre-existing filters in the case that either `add_filters` is specified, or `extract_filters` is `True` and filters are successfully extracted from `ddf`. Default is `True`. extract_filters Whether to extract filters from the task graph of `ddf`. Default is `True`. add_filters Custom filters to manually add to the IO layer of `ddf`. """ if not (extract_filters or add_filters): # Not extracting filters from the graph or # manually adding user-defined filters. Return return ddf # Check that we have a supported `ddf` object if not isinstance(ddf, dd.DataFrame): raise ValueError( f"Predicate pushdown optimization skipped. Type {type(ddf)} " f"does not support predicate pushdown." ) elif not isinstance(ddf.dask, HighLevelGraph): logger.warning( f"Predicate pushdown optimization skipped. Graph must be " f"a HighLevelGraph object (got {type(ddf.dask)})." ) return ddf # We were able to extract a DNF filter expression. # Check that we have a single IO layer with `filters` support io_layer = [] for k, v in ddf.dask.layers.items(): if isinstance(v, DataFrameIOLayer): io_layer.append(k) creation_info = ( (v.creation_info or {}) if hasattr(v, "creation_info") else {} ) if "filters" not in creation_info.get("kwargs", {}): # No filters support return ddf if len(io_layer) != 1: # Not a single IO layer return ddf io_layer = io_layer.pop() # Get pre-existing filters existing_filters = ( ddf.dask.layers[io_layer].creation_info.get("kwargs", {}).get("filters") ) # Start by converting the HLG to a `RegenerableGraph`. # Succeeding here means that all layers in the graph # are regenerable. try: dsk = RegenerableGraph.from_hlg(ddf.dask) except (ValueError, TypeError): logger.warning( "Predicate pushdown optimization skipped. One or more " "layers in the HighLevelGraph was not 'regenerable'." 
) return ddf name = ddf._name extracted_filters = DNF(None) if extract_filters: # Extract a DNF-formatted filter expression try: extracted_filters = dsk.layers[name]._dnf_filter_expression(dsk) except (ValueError, TypeError): # DNF dispatching failed for 1+ layers logger.warning( "Predicate pushdown optimization skipped. One or more " "layers has an unknown filter expression." ) # Combine filters filters = DNF(None) if preserve_filters: filters = filters.combine(existing_filters) if extract_filters: filters = filters.combine(extracted_filters) if add_filters: filters = filters.combine(add_filters) if not filters: # No filters encountered return ddf filters = filters.to_list_tuple() # FIXME: pyarrow doesn't seem to like converting datetime64[D] to scalars # so we must convert any we encounter to datetime64[ns] filters = [ [ ( col, op, val.astype("datetime64[ns]") if isinstance(val, np.datetime64) and val.dtype == "datetime64[D]" else val, ) for col, op, val in sublist ] for sublist in filters ] # Regenerate collection with filtered IO layer try: _regen_cache = {} return dsk.layers[name]._regenerate_collection( dsk, # TODO: shouldn't need to specify index=False after dask#9661 is merged new_kwargs={io_layer: {"filters": filters, "index": False}}, _regen_cache=_regen_cache, ) except ValueError as err: # Most-likely failed to apply filters in read_parquet. # We can just bail on predicate pushdown, but we also # raise a warning to encourage the user to file an issue. logger.warning( f"Predicate pushdown failed to apply filters: {filters}. 
class DNF:
    """Manage filters in Disjunctive Normal Form (DNF)

    Internally the expression is stored as an ``_Or`` of ``_And`` sets of
    ``(column, op, value)`` predicate tuples, or as ``None`` when there is
    no filter at all. An empty ``DNF`` is falsy.
    """

    class _Or(frozenset):
        """Frozen set of disjunctions"""

        def to_list_tuple(self) -> list:
            # DNF "or" is List[List[Tuple]]
            def _maybe_list(val):
                # A tuple *of predicates* becomes a list of predicates,
                # a single predicate gets wrapped into a one-element list
                if isinstance(val, tuple) and val and isinstance(val[0], (tuple, list)):
                    return list(val)
                return [val]

            return [
                _maybe_list(val.to_list_tuple())
                if hasattr(val, "to_list_tuple")
                else _maybe_list(val)
                for val in self
            ]

    class _And(frozenset):
        """Frozen set of conjunctions"""

        def to_list_tuple(self) -> tuple:
            # DNF "and" is a tuple of predicates (turned into a list by `_Or`)
            return tuple(
                val.to_list_tuple() if hasattr(val, "to_list_tuple") else val
                for val in self
            )

    _filters: _And | _Or | None  # Underlying filter expression

    def __init__(self, filters: DNF | _And | _Or | list | tuple | None) -> None:
        """Wrap ``filters``; anything that is not a ``DNF`` is normalized."""
        if isinstance(filters, DNF):
            self._filters = filters._filters
        else:
            self._filters = self.normalize(filters)

    def to_list_tuple(self) -> list:
        """Return the expression as ``List[List[Tuple]]``.

        An empty expression yields an empty list.
        """
        if not self._filters:
            return []
        return self._filters.to_list_tuple()

    def __bool__(self) -> bool:
        return bool(self._filters)

    @classmethod
    def normalize(cls, filters: _And | _Or | list | tuple | None):
        """Convert raw filters to the `_Or(_And)` DNF representation

        Returns ``None`` for empty input. Raises ``TypeError`` for
        unsupported input types or malformed predicates.
        """

        def _valid_tuple(predicate: tuple):
            col, op, val = predicate
            if isinstance(col, tuple):
                raise TypeError("filters must be List[Tuple] or List[List[Tuple]]")
            if op in ("in", "not in"):
                # Membership values must be hashable to live in a frozenset
                return (col, op, tuple(val))
            return predicate

        def _valid_list(conjunction: list):
            valid = []
            for predicate in conjunction:
                if not isinstance(predicate, tuple):
                    raise TypeError(f"Predicate must be a tuple, got {predicate}")
                valid.append(_valid_tuple(predicate))
            return valid

        if not filters:
            return None
        if isinstance(filters, list):
            # Either List[Tuple] (one conjunction) or List[List[Tuple]]
            conjunctions = filters if isinstance(filters[0], list) else [filters]
            return cls._Or(
                [cls._And(_valid_list(conjunction)) for conjunction in conjunctions]
            )
        if isinstance(filters, tuple):
            return cls._Or((cls._And((_valid_tuple(filters),)),))
        if isinstance(filters, cls._Or):
            # Flatten nested disjunctions
            return cls._Or(se for e in filters for se in cls.normalize(e))
        if isinstance(filters, cls._And):
            # Distribute "and" over the normalized "or" operands
            return cls._Or(
                cls._And(se for e in c for se in e)
                for c in itertools.product(*[cls.normalize(e) for e in filters])
            )
        raise TypeError(f"{type(filters)} not a supported type for DNF")

    def combine(self, other: DNF | _And | _Or | list | tuple | None) -> DNF:
        """Combine with another DNF object (logical "and")

        If either side is empty, the other side is returned unchanged.
        """
        if not isinstance(other, DNF):
            other = DNF(other)
        if self._filters is None:
            result = other._filters
        elif other._filters is None:
            result = self._filters
        else:
            result = self._And([self._filters, other._filters])
        return DNF(result)
# Predicate pushdown is supported for graphs # comprised of `Blockwise` layers based on these # operations _regenerable_ops = ( set(_comparison_symbols.keys()) | { operator.and_, operator.or_, operator.getitem, operator.inv, M.isin, M.isna, } | _pass_through_ops ) # Specify functions that must be generated with # a different API at the dataframe-collection level _special_op_mappings = { M.fillna: dd.DataFrame.fillna, M.isin: dd.DataFrame.isin, M.isna: dd.DataFrame.isna, M.astype: dd.DataFrame.astype, } # Convert _pass_through_ops to respect "special" mappings _pass_through_ops = {_special_op_mappings.get(op, op) for op in _pass_through_ops} def _preprocess_layers(input_layers): # NOTE: This is a Layer-specific work-around to deal with # the fact that `dd.DataFrame.isin(values)` will add a distinct # `MaterializedLayer` for the `values` argument. # See: https://github.com/dask-contrib/dask-sql/issues/607 skip = set() layers = input_layers.copy() for key, layer in layers.items(): if key.startswith("isin-") and isinstance(layer, Blockwise): indices = list(layer.indices) for i, (k, ind) in enumerate(layer.indices): if ( ind is None and isinstance(layers.get(k), MaterializedLayer) and isinstance(layers[k].get(k), (np.ndarray, tuple)) ): # Replace `indices[i]` with a literal value and # make sure we skip the `MaterializedLayer` that # we are now fusing into the `isin` value = layers[k][k] value = value[0](*value[1:]) if callable(value[0]) else value indices[i] = (value, None) skip.add(k) layer.indices = tuple(indices) return {k: v for k, v in layers.items() if k not in skip} class RegenerableLayer: """Regenerable Layer Wraps ``dask.highlevelgraph.Blockwise`` to ensure that a ``creation_info`` attribute is defined. This class also defines the necessary methods for recursive layer regeneration and filter-expression generation. 
""" def __init__(self, layer, creation_info): self.layer = layer # Original Blockwise layer reference self.creation_info = creation_info def _regenerate_collection( self, dsk, new_kwargs: dict = None, _regen_cache: dict = None, ): """Regenerate a Dask collection for this layer using the provided inputs and key-word arguments """ # Return regenerated layer if the work was # already done if _regen_cache is None: _regen_cache = {} if self.layer.output in _regen_cache: return _regen_cache[self.layer.output] # Recursively generate necessary inputs to # this layer to generate the collection inputs = [] for key, ind in self.layer.indices: if ind is None: if isinstance(key, (str, tuple)) and key in dsk.layers: continue inputs.append(key) elif key in self.layer.io_deps: continue else: inputs.append( dsk.layers[key]._regenerate_collection( dsk, new_kwargs=new_kwargs, _regen_cache=_regen_cache, ) ) # Extract the callable func and key-word args. # Then return a regenerated collection func = self.creation_info.get("func", None) if func is None: raise ValueError( "`_regenerate_collection` failed. " "Not all HLG layers are regenerable." 
) regen_args = self.creation_info.get("args", []) regen_kwargs = self.creation_info.get("kwargs", {}).copy() regen_kwargs = {k: v for k, v in self.creation_info.get("kwargs", {}).items()} regen_kwargs.update((new_kwargs or {}).get(self.layer.output, {})) result = func(*inputs, *regen_args, **regen_kwargs) _regen_cache[self.layer.output] = result return result def _dnf_filter_expression(self, dsk): """Return a DNF-formatted filter expression for the graph terminating at this layer """ op = self.creation_info["func"] if op in _comparison_symbols.keys(): func = _blockwise_comparison_dnf elif op in (operator.and_, operator.or_): func = _blockwise_logical_dnf elif op == operator.getitem: func = _blockwise_getitem_dnf elif op == dd.DataFrame.isin: func = _blockwise_isin_dnf elif op == dd.DataFrame.isna: func = _blockwise_isna_dnf elif op == operator.inv: func = _blockwise_inv_dnf elif op in _pass_through_ops: func = _blockwise_pass_through_dnf else: raise ValueError(f"No DNF expression for {op}") return func(op, self.layer.indices, dsk) class RegenerableGraph: """Regenerable Graph This class is similar to ``dask.highlevelgraph.HighLevelGraph``. However, all layers in a ``RegenerableGraph`` graph must be ``RegenerableLayer`` objects (which wrap ``Blockwise`` layers). 
""" def __init__(self, layers: dict): self.layers = layers @classmethod def from_hlg(cls, hlg: HighLevelGraph): """Construct a ``RegenerableGraph`` from a ``HighLevelGraph``""" if not isinstance(hlg, HighLevelGraph): raise TypeError(f"Expected HighLevelGraph, got {type(hlg)}") _layers = {} for key, layer in _preprocess_layers(hlg.layers).items(): regenerable_layer = None if isinstance(layer, DataFrameIOLayer): regenerable_layer = RegenerableLayer(layer, layer.creation_info or {}) elif isinstance(layer, Blockwise): tasks = list(layer.dsk.values()) if len(tasks) == 1 and tasks[0]: kwargs = {} if tasks[0][0] == apply: op = tasks[0][1] options = tasks[0][3] if isinstance(options, dict): kwargs = options elif ( isinstance(options, tuple) and options and callable(options[0]) ): kwargs = options[0](*options[1:]) else: op = tasks[0][0] if op in _regenerable_ops: regenerable_layer = RegenerableLayer( layer, { "func": _special_op_mappings.get(op, op), "kwargs": kwargs, }, ) if regenerable_layer is None: raise ValueError(f"Graph contains non-regenerable layer: {layer}") _layers[key] = regenerable_layer return RegenerableGraph(_layers) def _get_blockwise_input(input_index, indices: list, dsk: RegenerableGraph): # Simple utility to get the required input expressions # for a Blockwise layer (using indices) key = indices[input_index][0] if indices[input_index][1] is None: return key return dsk.layers[key]._dnf_filter_expression(dsk) def _inv(symbol: str): return { ">": "<", "<": ">", ">=": "<=", "<=": ">=", "in": "not in", "not in": "in", "is": "is not", "is not": "is", }.get(symbol, symbol) def _blockwise_comparison_dnf(op, indices: list, dsk: RegenerableGraph) -> DNF: # Return DNF expression pattern for a simple comparison left = _get_blockwise_input(0, indices, dsk) right = _get_blockwise_input(1, indices, dsk) if is_arraylike(left) and hasattr(left, "item") and left.size == 1: left = left.item() # Need inverse comparison in read_parquet return DNF((right, 
# Logical negation of every filter symbol produced by the DNF helpers.
# This is deliberately NOT the operand-swapping table used for comparisons
# with the literal on the left-hand side: `~(x > 5)` is `x <= 5`, not `x < 5`.
_negated_symbols = {
    "==": "!=",
    "!=": "==",
    "<": ">=",
    "<=": ">",
    ">": "<=",
    ">=": "<",
    "in": "not in",
    "not in": "in",
    "is": "is not",
    "is not": "is",
}


def _blockwise_inv_dnf(op, indices: list, dsk: RegenerableGraph) -> DNF:
    """Return DNF expression pattern for the inverse (``~``) of a filter

    Only the inversion of a single predicate is supported.

    Raises
    ------
    ValueError
        If the input expression holds more than one predicate, or a
        predicate uses a symbol without a known negation.
    """
    expr = _get_blockwise_input(0, indices, dsk).to_list_tuple()

    # Haven't taken the time to think through
    # general inversion yet.
    if sum(len(conjunction) for conjunction in expr) > 1:
        raise ValueError("inv(DNF) case not implemented.")

    # NOTE(review): null handling of the negated predicate at the IO level
    # may differ from the in-memory `~` result - confirm for nullable columns.
    new_expr = []
    for conjunction in expr:
        new_conjunction = []
        for col, symbol, val in conjunction:
            if symbol not in _negated_symbols:
                raise ValueError(f"No negation defined for filter symbol {symbol!r}")
            new_conjunction.append((col, _negated_symbols[symbol], val))
        new_expr.append(DNF._And(new_conjunction))
    return DNF(DNF._Or(new_expr))
+ v.__qualname__ for k, v in all_estimators() } except ImportError: cpu_classes = {} cpu_classes = add_boosting_classes(cpu_classes) return cpu_classes def get_gpu_classes(): gpu_classes = { # cuml.dask "DBSCAN": "cuml.dask.cluster.dbscan.DBSCAN", "KMeans": "cuml.dask.cluster.kmeans.KMeans", "PCA": "cuml.dask.decomposition.pca.PCA", "TruncatedSVD": "cuml.dask.decomposition.tsvd.TruncatedSVD", "RandomForestClassifier": "cuml.dask.ensemble.randomforestclassifier.RandomForestClassifier", "RandomForestRegressor": "cuml.dask.ensemble.randomforestregressor.RandomForestRegressor", "LogisticRegression": "cuml.dask.extended.linear_model.logistic_regression.LogisticRegression", "TfidfTransformer": "cuml.dask.feature_extraction.text.tfidf_transformer.TfidfTransformer", "LinearRegression": "cuml.dask.linear_model.linear_regression.LinearRegression", "Ridge": "cuml.dask.linear_model.ridge.Ridge", "Lasso": "cuml.dask.linear_model.lasso.Lasso", "ElasticNet": "cuml.dask.linear_model.elastic_net.ElasticNet", "UMAP": "cuml.dask.manifold.umap.UMAP", "MultinomialNB": "cuml.dask.naive_bayes.naive_bayes.MultinomialNB", "NearestNeighbors": "cuml.dask.neighbors.nearest_neighbors.NearestNeighbors", "KNeighborsClassifier": "cuml.dask.neighbors.kneighbors_classifier.KNeighborsClassifier", "KNeighborsRegressor": "cuml.dask.neighbors.kneighbors_regressor.KNeighborsRegressor", "LabelBinarizer": "cuml.dask.preprocessing.label.LabelBinarizer", "OneHotEncoder": "cuml.dask.preprocessing.encoders.OneHotEncoder", "LabelEncoder": "cuml.dask.preprocessing.LabelEncoder.LabelEncoder", "CD": "cuml.dask.solvers.cd.CD", # cuml "Base": "cuml.internals.base.Base", "Handle": "cuml.common.handle.Handle", "AgglomerativeClustering": "cuml.cluster.agglomerative.AgglomerativeClustering", "HDBSCAN": "cuml.cluster.hdbscan.HDBSCAN", "IncrementalPCA": "cuml.decomposition.incremental_pca.IncrementalPCA", "ForestInference": "cuml.fil.fil.ForestInference", "KernelRidge": "cuml.kernel_ridge.kernel_ridge.KernelRidge", 
"MBSGDClassifier": "cuml.linear_model.mbsgd_classifier.MBSGDClassifier", "MBSGDRegressor": "cuml.linear_model.mbsgd_regressor.MBSGDRegressor", "TSNE": "cuml.manifold.t_sne.TSNE", "KernelDensity": "cuml.neighbors.kernel_density.KernelDensity", "GaussianRandomProjection": "cuml.random_projection.random_projection.GaussianRandomProjection", "SparseRandomProjection": "cuml.random_projection.random_projection.SparseRandomProjection", "SGD": "cuml.solvers.sgd.SGD", "QN": "cuml.solvers.qn.QN", "SVC": "cuml.svm.SVC", "SVR": "cuml.svm.SVR", "LinearSVC": "cuml.svm.LinearSVC", "LinearSVR": "cuml.svm.LinearSVR", "ARIMA": "cuml.tsa.arima.ARIMA", "AutoARIMA": "cuml.tsa.auto_arima.AutoARIMA", "ExponentialSmoothing": "cuml.tsa.holtwinters.ExponentialSmoothing", # sklearn "Binarizer": "cuml.preprocessing.Binarizer", "KernelCenterer": "cuml.preprocessing.KernelCenterer", "MinMaxScaler": "cuml.preprocessing.MinMaxScaler", "MaxAbsScaler": "cuml.preprocessing.MaxAbsScaler", "Normalizer": "cuml.preprocessing.Normalizer", "PolynomialFeatures": "cuml.preprocessing.PolynomialFeatures", "PowerTransformer": "cuml.preprocessing.PowerTransformer", "QuantileTransformer": "cuml.preprocessing.QuantileTransformer", "RobustScaler": "cuml.preprocessing.RobustScaler", "StandardScaler": "cuml.preprocessing.StandardScaler", "SimpleImputer": "cuml.preprocessing.SimpleImputer", "MissingIndicator": "cuml.preprocessing.MissingIndicator", "KBinsDiscretizer": "cuml.preprocessing.KBinsDiscretizer", "FunctionTransformer": "cuml.preprocessing.FunctionTransformer", "ColumnTransformer": "cuml.compose.ColumnTransformer", "GridSearchCV": "sklearn.model_selection.GridSearchCV", "Pipeline": "sklearn.pipeline.Pipeline", # Other "UniversalBase": "cuml.internals.base.UniversalBase", "Lars": "cuml.experimental.linear_model.lars.Lars", "TfidfVectorizer": "cuml.feature_extraction._tfidf_vectorizer.TfidfVectorizer", "CountVectorizer": "cuml.feature_extraction._vectorizers.CountVectorizer", "HashingVectorizer": 
def add_boosting_classes(my_classes):
    """Register the LightGBM and XGBoost estimators in ``my_classes``.

    The mapping is updated in place (class name -> import path) and
    returned for convenience.
    """
    my_classes.update(
        {
            # lightgbm
            "LGBMModel": "lightgbm.LGBMModel",
            "LGBMClassifier": "lightgbm.LGBMClassifier",
            "LGBMRegressor": "lightgbm.LGBMRegressor",
            "LGBMRanker": "lightgbm.LGBMRanker",
            # xgboost
            "XGBRegressor": "xgboost.XGBRegressor",
            "XGBClassifier": "xgboost.XGBClassifier",
            "XGBRanker": "xgboost.XGBRanker",
            "XGBRFRegressor": "xgboost.XGBRFRegressor",
            "XGBRFClassifier": "xgboost.XGBRFClassifier",
            # xgboost.dask
            "DaskXGBClassifier": "xgboost.dask.DaskXGBClassifier",
            "DaskXGBRegressor": "xgboost.dask.DaskXGBRegressor",
            "DaskXGBRanker": "xgboost.dask.DaskXGBRanker",
            "DaskXGBRFRegressor": "xgboost.dask.DaskXGBRFRegressor",
            "DaskXGBRFClassifier": "xgboost.dask.DaskXGBRFClassifier",
        }
    )
    return my_classes
def sort_partition_func(
    partition: pd.DataFrame,
    sort_columns: list[str],
    sort_ascending: list[bool],
    sort_null_first: list[bool],
    **kwargs,
):
    """
    Sort a single partition by several columns, each with its own
    direction and null position.

    Any additional keyword arguments are accepted and ignored.
    """
    if partition.empty:
        return partition

    # Trick: https://github.com/pandas-dev/pandas/issues/17111
    # Sorting column by column with a stable algorithm, starting at the
    # least significant (last) column, is faster than a single multi-column
    # sort and makes a per-column NaN position possible.
    sort_keys = list(zip(sort_columns, sort_ascending, sort_null_first))
    ordered = partition
    for column, ascending, nulls_first in sort_keys[::-1]:
        ordered = ordered.sort_values(
            by=[column],
            ascending=ascending,
            na_position="first" if nulls_first else "last",
            kind="mergesort",
        )

    return ordered
statistics from a Dask DataFrame collection WARNING: This API is experimental Parameters ---------- ddf Dask-DataFrame object to extract Parquet statistics from. columns List of columns to collect min/max statistics for. If ``None`` (the default), only 'num-rows' statistics will be collected. parallel The number of distinct files to collect statistics for within a distinct ``dask.delayed`` task. If ``False``, all statistics will be parsed on the client process. If ``None``, the value will be set to 16 for remote filesystem (e.g s3) and ``False`` otherwise. Default is ``None``. **compute_kwargs Key-word arguments to pass through to ``dask.compute`` when ``parallel`` is not ``False``. Returns ------- statistics List of Parquet statistics. Each list element corresponds to a distinct partition in ``ddf``. Each element of ``statistics`` will correspond to a dictionary with 'num-rows' and 'columns' keys:: ``{'num-rows': 1024, 'columns': [...]}`` If column statistics are available, each element of the list stored under the "columns" key will correspond to a dictionary with "name", "min", and "max" keys:: ``{'name': 'col0', 'min': 0, 'max': 100}`` """ # Check that we have a supported `ddf` object if not isinstance(ddf, dd.DataFrame): raise ValueError(f"Expected Dask DataFrame, got {type(ddf)}.") # Be strict about columns argument if columns: if not isinstance(columns, list): raise ValueError(f"Expected columns to be a list, got {type(columns)}.") elif not set(columns).issubset(set(ddf.columns)): raise ValueError(f"columns={columns} must be a subset of {ddf.columns}") # Extract "read-parquet" layer from ddf try: layer = hlg_layer(ddf.dask, "read-parquet") except KeyError: layer = None # Make sure we are dealing with a # ParquetFunctionWrapper-based DataFrameIOLayer if not isinstance(layer, DataFrameIOLayer) or not isinstance( layer.io_func, ParquetFunctionWrapper ): logger.debug( f"Could not extract Parquet statistics from {ddf}." 
f"\nAttempted IO layer: {layer}" ) return None # Collect statistics using layer information parts = layer.inputs fs = layer.io_func.fs engine = layer.io_func.engine if not issubclass(engine, ArrowDatasetEngine): logger.debug( f"Could not extract Parquet statistics from {ddf}." f"\nUnsupported parquet engine: {engine}" ) return None # Set default if parallel is None: parallel = False if _is_local_fs(fs) else 16 parallel = int(parallel) if parallel: # Group parts corresponding to the same file. # A single task should always parse statistics # for all these parts at once (since they will # all be in the same footer) groups = defaultdict(list) for part in parts: for p in [part] if isinstance(part, dict) else part: path = p.get("piece")[0] groups[path].append(p) group_keys = list(groups.keys()) # Compute and return flattened result func = delayed(_read_partition_stats_group) result = dask.compute( [ func( list( itertools.chain( *[groups[k] for k in group_keys[i : i + parallel]] ) ), fs, engine, columns=columns, ) for i in range(0, len(group_keys), parallel) ], **(compute_kwargs or {}), )[0] return list(itertools.chain(*result)) else: # Serial computation on client return _read_partition_stats_group(parts, fs, engine, columns=columns) def _read_partition_stats_group(parts, fs, engine, columns=None): def _read_partition_stats(part, fs, columns=None): # Helper function to read Parquet-metadata # statistics for a single partition if not isinstance(part, list): part = [part] column_stats = {} num_rows = 0 columns = columns or [] for p in part: piece = p["piece"] path = piece[0] row_groups = None if piece[1] == [None] else piece[1] md = _get_md(path, fs) if row_groups is None: row_groups = list(range(md.num_row_groups)) for rg in row_groups: row_group = md.row_group(rg) num_rows += row_group.num_rows for i in range(row_group.num_columns): col = row_group.column(i) name = col.path_in_schema if name in columns: if col.statistics and col.statistics.has_min_max: if name in 
column_stats: column_stats[name]["min"] = min( column_stats[name]["min"], col.statistics.min ) column_stats[name]["max"] = max( column_stats[name]["max"], col.statistics.max ) else: column_stats[name] = { "min": col.statistics.min, "max": col.statistics.max, } # Convert dict-of-dict to list-of-dict to be consistent # with current `dd.read_parquet` convention (for now) column_stats_list = [ { "name": name, "min": column_stats[name]["min"], "max": column_stats[name]["max"], } for name in column_stats.keys() ] return {"num-rows": num_rows, "columns": column_stats_list} @lru_cache(maxsize=1) def _get_md(path, fs): # Caching utility to avoid parsing the same footer # metadata multiple times with fs.open(path, default_cache="none") as f: return pq.ParquetFile(f).metadata # Helper function used by _extract_statistics return [_read_partition_stats(part, fs, columns=columns) for part in parts] ================================================ FILE: dask_sql/server/__init__.py ================================================ ================================================ FILE: dask_sql/server/app.py ================================================ import asyncio import logging from argparse import ArgumentParser from uuid import uuid4 import dask.distributed import uvicorn from fastapi import FastAPI, HTTPException, Request from uvicorn import Config, Server from dask_sql.context import Context from dask_sql.server.presto_jdbc import create_meta_data from dask_sql.server.responses import DataResults, ErrorResults, QueryResults app = FastAPI() logger = logging.getLogger(__name__) @app.get("/v1/empty") async def empty(request: Request): """ Helper endpoint returning an empty result. 
""" return QueryResults(request=request) @app.delete("/v1/cancel/{uuid}") async def cancel(uuid: str, request: Request): """ Cancel an already running computation """ logger.debug(f"Canceling the request with uuid {uuid}") try: future = request.app.future_list[uuid] except KeyError: raise HTTPException(status_code=404, detail="uuid not found") future.cancel() del request.app.future_list[uuid] return {"status": "ok"} @app.get("/v1/status/{uuid}") async def status(uuid: str, request: Request): """ Return the status (or the result) of an already running calculation """ logger.debug(f"Accessing the request with uuid {uuid}") try: future = request.app.future_list[uuid] except KeyError: raise HTTPException(status_code=404, detail="uuid not found") if future.done(): logger.debug(f"{uuid} is already finished, returning data") df = future.result() del request.app.future_list[uuid] return DataResults(df, request=request) logger.debug(f"{uuid} is not already finished") status_url = str(request.url) return QueryResults(request=request, next_url=status_url) @app.post("/v1/statement") async def query(request: Request): """ Main endpoint returning query results in the presto on wire format. 
""" try: sql = (await request.body()).decode().strip() # required for PrestoDB JDBC driver compatibility # replaces queries to unsupported `system` catalog with queries to `system_jdbc` # schema created by `create_meta_data(context)` when `jdbc_metadata=True` # TODO: explore Trino which should make JDBC compatibility easier but requires # changing response headers (see https://github.com/dask-contrib/dask-sql/pull/351) sql = sql.replace("system.jdbc", "system_jdbc") df = request.app.c.sql(sql) if df is None: return DataResults(df, request) uuid = str(uuid4()) request.app.future_list[uuid] = request.app.client.compute(df) logger.debug(f"Registering {sql} with uuid {uuid}.") status_url = str( request.url.replace(path=request.app.url_path_for("status", uuid=uuid)) ) cancel_url = str( request.url.replace(path=request.app.url_path_for("cancel", uuid=uuid)) ) return QueryResults(request=request, next_url=status_url, cancel_url=cancel_url) except Exception as e: return ErrorResults(e, request=request) def run_server( context: Context = None, client: dask.distributed.Client = None, host: str = "0.0.0.0", port: int = 8080, startup=False, log_level=None, blocking: bool = True, jdbc_metadata: bool = False, ): # pragma: no cover """ Run a HTTP server for answering SQL queries using ``dask-sql``. It uses the `Presto Wire Protocol `_ for communication. This means, it has a single POST endpoint `/v1/statement`, which answers SQL queries (as string in the body) with the output as a JSON (in the format described in the documentation above). Every SQL expression that ``dask-sql`` understands can be used here. See :ref:`server` for more information. Note: The presto protocol also includes some statistics on the query in the response. These statistics are currently only filled with placeholder variables. Args: context (:obj:`dask_sql.Context`): If set, use this context instead of an empty one. client (:obj:`dask.distributed.Client`): If set, use this dask client instead of a new one. 
host (:obj:`str`): The host interface to listen on (defaults to all interfaces) port (:obj:`int`): The port to listen on (defaults to 8080) startup (:obj:`bool`): Whether to wait until Apache Calcite was loaded log_level: (:obj:`str`): The log level of the server and dask-sql blocking: (:obj:`bool`): If running in an environment with an event loop (e.g. a jupyter notebook), do not block. The server can be stopped with `context.stop_server()` afterwards. jdbc_metadata: (:obj:`bool`): If enabled create JDBC metadata tables using schemas and tables in the current dask_sql context Example: It is possible to run an SQL server by using the CLI script ``dask-sql-server`` or by calling this function directly in your user code: .. code-block:: python from dask_sql import run_server # Create your pre-filled context c = Context() ... run_server(context=c) After starting the server, it is possible to send queries to it, e.g. with the `presto CLI `_ or via sqlalchemy (e.g. using the `PyHive `_ package): .. code-block:: python from sqlalchemy.engine import create_engine engine = create_engine('presto://localhost:8080/') import pandas as pd pd.read_sql_query("SELECT 1 + 1", con=engine) Of course, it is also possible to call the usual ``CREATE TABLE`` commands. If in a jupyter notebook, you should run the following code .. code-block:: python from dask_sql import Context c = Context() c.run_server(blocking=False) ... c.stop_server() Note: When running in a jupyter notebook without blocking, it is not possible to access the SQL server from within the notebook, e.g. using sqlalchemy. Doing so will deadlock infinitely. 
""" if context is None: context = Context() _init_app(app, context=context, client=client) if jdbc_metadata: create_meta_data(context) if startup: app.c.sql("SELECT 1 + 1").compute() config = Config(app, host=host, port=port, log_level=log_level) server = Server(config=config) if blocking: server.run() else: loop = asyncio.get_event_loop() loop.create_task(server.serve()) context.sql_server = server def main(): # pragma: no cover """ CLI version of the :func:`run_server` function. """ parser = ArgumentParser() parser.add_argument( "--host", default="0.0.0.0", help="The host interface to listen on (defaults to all interfaces)", ) parser.add_argument( "--port", default=8080, help="The port to listen on (defaults to 8080)" ) parser.add_argument( "--scheduler-address", default=None, help="Connect to this dask scheduler if given", ) parser.add_argument( "--log-level", default=None, help="Set the log level of the server. Defaults to info.", choices=uvicorn.config.LOG_LEVELS, ) parser.add_argument( "--load-test-data", default=False, action="store_true", help="Preload some test data.", ) parser.add_argument( "--startup", default=False, action="store_true", help="Wait until Apache Calcite was properly loaded", ) args = parser.parse_args() client = None if args.scheduler_address: client = dask.distributed.Client(args.scheduler_address) context = Context() if args.load_test_data: df = dask.datasets.timeseries(freq="1d").reset_index(drop=False) context.create_table("timeseries", df.persist()) run_server( context=context, client=client, host=args.host, port=args.port, startup=args.startup, log_level=args.log_level, ) def _init_app( app: FastAPI, context: Context = None, client: dask.distributed.Client = None, ): app.c = context app.future_list = {} try: client = client or dask.distributed.Client.current() except ValueError: client = dask.distributed.Client() app.client = client ================================================ FILE: dask_sql/server/presto_jdbc.py 
def create_meta_data(c: Context):
    """
    Create the schema, table and column meta data tables for the prestodb
    JDBC driver, so that data can be viewed in a database tool like DBeaver.

    No catalog entry is created (although JDBC expects one), as dask-sql
    does not support catalogs in the presto interface.

    The meta data is stored in a separate schema called ``system_jdbc``:
    the JDBC driver tries to access ``system.jdbc``, and the name is
    sufficiently unusual that it should not clash with other schemas.
    Queries of the driver therefore need ``system.jdbc`` to be rewritten
    to ``system_jdbc`` before they are executed.

    If ``c`` is ``None`` a warning is logged and nothing is created.

    :param c: Context containing created tables
    :return: None
    """

    if c is None:
        logger.warning("Context None: jdbc meta data not created")
        return
    catalog = ""
    system_schema = "system_jdbc"
    c.create_schema(system_schema)

    # TODO: add support for catalogs in presto interface
    # see https://github.com/dask-contrib/dask-sql/pull/351
    # if catalog and len(catalog.strip()) > 0:
    #     catalogs = pd.DataFrame().append(create_catalog_row(catalog), ignore_index=True)
    #     c.create_table("catalogs", catalogs, schema_name=system_schema)

    # Register placeholder tables (a single empty row each) before collecting
    # the real rows - presumably so that the meta data tables themselves are
    # picked up by the loop over all schemas below.
    schemas = pd.DataFrame(create_schema_row(), index=[0])
    c.create_table("schemas", schemas, schema_name=system_schema)
    tables = pd.DataFrame(create_table_row(), index=[0])
    c.create_table("tables", tables, schema_name=system_schema)
    columns = pd.DataFrame(create_column_row(), index=[0])
    c.create_table("columns", columns, schema_name=system_schema)

    schema_rows = []
    table_rows = []
    column_rows = []

    for schema_name, schema in c.schema.items():
        schema_rows.append(create_schema_row(catalog, schema_name))
        for table_name, dc in schema.tables.items():
            df = dc.df
            logger.info(f"schema {schema_name}, table {table_name}, {df}")
            table_rows.append(create_table_row(catalog, schema_name, table_name))
            # JDBC ordinal positions are 1-based
            for pos, column in enumerate(df.columns, start=1):
                logger.debug(f"column {column}")
                column_dtype = df[column].dtype
                # everything which is not recognized is reported as VARCHAR
                dtype = "VARCHAR"
                if column_dtype == "int64" or column_dtype == "int":
                    dtype = "INTEGER"
                elif column_dtype == "float64" or column_dtype == "float":
                    dtype = "FLOAT"
                elif column_dtype == "datetime" or column_dtype == "datetime64[ns]":
                    dtype = "TIMESTAMP"
                column_rows.append(
                    create_column_row(
                        catalog,
                        schema_name,
                        table_name,
                        dtype,
                        df[column].name,
                        str(pos),
                    )
                )

    # Replace the placeholder tables with the collected meta data
    schemas = pd.DataFrame(schema_rows)
    c.create_table("schemas", schemas, schema_name=system_schema)
    tables = pd.DataFrame(table_rows)
    c.create_table("tables", tables, schema_name=system_schema)
    columns = pd.DataFrame(column_rows)
    c.create_table("columns", columns, schema_name=system_schema)

    logger.info(f"jdbc meta data ready for {len(table_rows)} tables")
"IS_AUTOINCREMENT": "", "IS_GENERATEDCOLUMN": "", } ================================================ FILE: dask_sql/server/responses.py ================================================ import uuid import dask.dataframe as dd import numpy as np import pandas as pd from fastapi import Request from dask_sql.mappings import python_to_sql_type class StageStats: def __init__(self): self.stageId = "" self.state = "" self.done = True self.nodes = 0 self.totalSplits = 0 self.queuedSplits = 0 self.runningSplits = 0 self.completedSplits = 0 self.cpuTimeMillis = 0 self.wallTimeMillis = 0 self.processedRows = 0 self.processedBytes = 0 self.subStages = [] class StatementStats: def __init__(self): self.state = "" self.queued = False self.scheduled = False self.nodes = 0 self.totalSplits = 0 self.queuedSplits = 0 self.runningSplits = 0 self.completedSplits = 0 self.cpuTimeMillis = 0 self.wallTimeMillis = 0 self.queuedTimeMillis = 0 self.elapsedTimeMillis = 0 self.processedRows = 0 self.processedBytes = 0 self.peakMemoryBytes = 0 self.peakTotalMemoryBytes = 0 self.peakTaskTotalMemoryBytes = 0 self.spilledBytes = 0 self.rootStage = StageStats() class QueryResults: def __init__(self, request: Request, next_url: str = None, cancel_url: str = None): empty_url = str(request.url.replace(path=request.app.url_path_for("empty"))) self.id = str(uuid.uuid4()) self.infoUri = empty_url if next_url: self.nextUri = next_url if cancel_url: self.partialCancelUri = cancel_url self.stats = StatementStats() self.warnings = [] class DataResults(QueryResults): @staticmethod def get_column_description(df): sql_types = [str(python_to_sql_type(t)).lower() for t in df.dtypes] column_names = df.columns return [ { "name": column_name, "type": sql_type, "typeSignature": { "rawType": sql_type, "arguments": [] if sql_type not in ("char", "varchar") else [{"kind": "LONG", "value": 10}], }, } for column_name, sql_type in zip(column_names, sql_types) ] @staticmethod def get_data_description(df): if hasattr(df, 
"to_pandas"): df = df.to_pandas() return [ DataResults.convert_row(row) for row in df.itertuples(index=False, name=None) ] @staticmethod def convert_cell(cell): try: if pd.isna(cell): return None elif np.isnan(cell): # pragma: no cover return "NaN" elif np.isposinf(cell): return "+Infinity" elif np.isneginf(cell): # pragma: no cover return "-Infinity" except TypeError: # pragma: no cover pass try: return cell.item() except AttributeError: pass return cell @staticmethod def convert_row(row): return [DataResults.convert_cell(cell) for cell in row] def __init__(self, df: dd.DataFrame, request: Request): super().__init__(request) if df is None: return self.columns = self.get_column_description(df) self.data = self.get_data_description(df) class ErrorResults(QueryResults): def __init__(self, error: Exception, request: Request): super().__init__(request) self.error = QueryError(error) class QueryError: def __init__(self, error: Exception): self.message = str(error) self.errorCode = 0 self.errorName = str(type(error)) self.errorType = "USER_ERROR" # FIXME: ParserErrors currently don't contain information on where the syntax error occurred # try: # self.errorLocation = { # "lineNumber": error.from_line + 1, # "columnNumber": error.from_col + 1, # } # except AttributeError: # pragma: no cover # pass ================================================ FILE: dask_sql/sql-schema.yaml ================================================ properties: sql: type: object properties: aggregate: type: object properties: split_out: type: integer description: | Number of output partitions from an aggregation operation split_every: type: [integer, "null"] description: | Number of branches per reduction step from an aggregation operation. identifier: type: object properties: case_sensitive: type: boolean description: | Whether sql identifiers are considered case sensitive while parsing. 
join: type: object properties: broadcast: type: [boolean, number, "null"] description: | If boolean, it determines whether all joins should use the broadcast join algorithm. If float, it's a value denoting dask's likelihood of selecting a broadcast join based codepath over a shuffle based join. Concretely, dask will select a broadcast based join algorithm if small_table.npartitions < log2(big_table.npartitions) * broadcast_bias Note: Forcing a broadcast join might lead to perf issues or OOM errors in cases where the broadcasted table is too large to fit on a single worker. limit: type: object properties: check-first-partition: type: boolean description: | Whether or not to check the first partition length when computing a LIMIT without an OFFSET on a table with a relatively simple Dask graph (i.e. only IO and/or partition-wise layers); checking partition length triggers a Dask graph computation which can be slow for complex queries, but can significantly reduce memory usage when querying a small subset of a large table. Default is ``true``. optimize: type: boolean description: | Whether the first generated logical plan should be further optimized or used as is. predicate_pushdown: type: boolean description: | Whether to try pushing down filter predicates into IO (when possible). dynamic_partition_pruning: type: boolean description: | Whether to apply the dynamic partition pruning optimizer rule. optimizer: type: object properties: verbose: type: boolean description: | The dynamic partition pruning optimizer rule can sometimes result in extremely long c.explain() outputs which are not helpful to the user. Setting this option to true allows the user to see the entire output, while setting it to false truncates the output. Default is false. fact_dimension_ratio: type: [number, "null"] description: | Ratio of the size of the dimension tables to fact tables. Parameter for dynamic partition pruning and join reorder optimizer rules.
max_fact_tables: type: [integer, "null"] description: | Maximum number of fact tables to allow in a join. Parameter for join reorder optimizer rule. preserve_user_order: type: [boolean, "null"] description: | Whether to preserve user-defined order of unfiltered dimensions. Parameter for join reorder optimizer rule. filter_selectivity: type: [number, "null"] description: | Constant to use when determining the number of rows produced by a filtered relation. Parameter for join reorder optimizer rule. sort: type: object properties: topk-nelem-limit: type: integer description: | Total number of elements below which dask-sql should attempt to apply the top-k optimization (when possible). ``nelem`` is defined as the limit or ``k`` value times the number of columns. Default is 1000000, corresponding to a LIMIT clause of 1 million in a 1 column table. mappings: type: object properties: decimal_support: type: string description: Decides how to handle decimal scalars/columns. ``"pandas"`` handling will treat decimal scalars and columns as floats and float64 columns, respectively, while ``"cudf"`` handling treats decimal scalars as ``decimal.Decimal`` objects and decimal columns as ``cudf.Decimal128Dtype`` columns, handling precision/scale accordingly. Default is ``"pandas"``, but ``"cudf"`` should be used if attempting to work with decimal columns on GPU.
class Pluggable:
    """
    Helper class for everything which can be extended by plugins.
    Basically just a mapping of a name to the stored plugin
    for every class.
    Please note that the plugins are stored
    in this single class, which makes simple extensions possible.
    """

    # One registry for all subclasses: class -> {name -> plugin}
    __plugins = defaultdict(dict)

    @classmethod
    def add_plugin(cls, names, plugin, replace=True):
        """
        Add a plugin with the given name.

        ``names`` is either a single name or a list of names, which are
        all registered for the same ``plugin``.
        If ``replace`` is false, names which are already registered for
        this class keep their current plugin; only the names which are
        not known so far are added.
        """
        if isinstance(names, str):
            names = [names]

        registry = Pluggable.__plugins[cls]
        if not replace:
            # never overwrite an existing registration, also not if
            # only some of the given names are already taken
            names = [name for name in names if name not in registry]

        registry.update({name: plugin for name in names})

    @classmethod
    def get_plugin(cls, name):
        """Get a plugin with the given name"""
        return Pluggable.__plugins[cls][name]

    @classmethod
    def get_plugins(cls):
        """Return all registered plugins"""
        return list(Pluggable.__plugins[cls].values())
f"DataFrame: {[(col, dtype) for col, dtype in zip(cols, dtypes)]}" return f"Literal: {df}" def convert_sql_kwargs( sql_kwargs: dict[str, str], ) -> dict[str, Any]: """ Convert the Rust Vec of key/value pairs into a Dict containing the keys and values """ def convert_literal(value): if value.isCollection(): operator_mapping = { "SqlTypeName.ARRAY": list, "SqlTypeName.MAP": lambda x: dict(zip(x[::2], x[1::2])), "SqlTypeName.MULTISET": set, "SqlTypeName.ROW": tuple, } operator = operator_mapping[str(value.getSqlType())] operands = [convert_literal(o) for o in value.getOperandList()] return operator(operands) elif value.isKwargs(): return convert_sql_kwargs(value.getKwargs()) else: literal_type = value.getSqlType() literal_value = value.getSqlValue() if literal_type == SqlTypeName.VARCHAR: return value.getSqlValue() elif literal_type == SqlTypeName.BIGINT and "." in literal_value: literal_type = SqlTypeName.DOUBLE python_value = sql_to_python_value(literal_type, literal_value) return python_value return {key: convert_literal(value) for key, value in dict(sql_kwargs).items()} def import_class(name: str) -> type: """ Import a class with the given name by loading the module and referencing the class in the module """ module_path, class_name = name.rsplit(".", 1) module = importlib.import_module(module_path) return getattr(module, class_name) def new_temporary_column(df: dd.DataFrame) -> str: """Return a new column name which is currently not in use""" while True: col_name = str(uuid4()) if col_name not in df.columns: return col_name else: # pragma: no cover continue ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = build # Put it first so that "make" without argument is like "make help". 
help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/environment.yml ================================================ name: dask-sql-docs channels: - conda-forge dependencies: - python=3.9 - sphinx>=4.0.0 - sphinx-tabs - dask-sphinx-theme>=2.0.3 - dask>=2024.4.1 - pandas>=1.4.0 - fugue>=0.7.3 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 - fastapi>=0.92.0 - httpx>=0.24.1 - uvicorn>=0.14 - tzlocal>=2.1 - prompt_toolkit>=3.0.8 - pygments>=2.7.1 - tabulate - ucx-proc=*=cpu - rust=1.72 ================================================ FILE: docs/make.bat ================================================ @ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=source set BUILDDIR=build if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. 
echo.If you don't have Sphinx installed, grab it from echo.https://www.sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd ================================================ FILE: docs/requirements-docs.txt ================================================ sphinx>=4.0.0 sphinx-tabs dask-sphinx-theme>=3.0.0 dask>=2024.4.1 pandas>=1.4.0 fugue>=0.7.3 # FIXME: https://github.com/fugue-project/fugue/issues/526 triad<0.9.2 fastapi>=0.92.0 httpx>=0.24.1 uvicorn>=0.14 tzlocal>=2.1 prompt_toolkit>=3.0.8 pygments>=2.7.1 tabulate maturin>=1.3,<1.4 ================================================ FILE: docs/source/api.rst ================================================ .. _api: API Documentation ================= .. autoclass:: dask_sql.Context :members: :undoc-members: .. autofunction:: dask_sql.run_server .. autofunction:: dask_sql.cmd_loop .. autoclass:: dask_sql.integrations.fugue.DaskSQLExecutionEngine :members: .. autofunction:: dask_sql.integrations.fugue.fsql_dask ================================================ FILE: docs/source/best_practices.rst ================================================ .. _best_practices: Best Practices and Performance Tips =================================== Sort and Use Read Filtering --------------------------- If you often read by key ranges or perform lots of logic with groups of related records, you should consider using Dask Dataframe's `shuffle `_. This operation ensures that all rows of a given key will be within a single partition. This is helpful for querying records on a specific key or keys such as customer IDs or session keys, as it allows Dask to skip partitions based on the partition min and max values thus avoiding reading each record. This can save a large amount of IO time and is especially helpful when using a network file system. 
For example, querying a specific pickup time from a taxi dataset ends up returning a result with over 200 partitions as each of these partitions needs to be checked for that key. .. code-block:: python ddf = dd.read_parquet('/data/taxi_pq_2GB', split_row_groups=False) c.create_table('taxi_unsorted', ddf) c.sql("select * from taxi_unsorted where DAYOFMONTH(pickup_datetime) = 15").npartitions .. code-block:: 244 But, if you were to instead sort by the pickup time and use the ``DISTRIBUTE BY`` operation, which is equivalent to Dask Dataframe's shuffle, you can reduce the number of partitions in the result to 1. .. code-block:: python def intra_partition_sort(df, sort_keys): return df.sort_values(sort_keys) c.sql(""" SELECT DAYOFMONTH(pickup_datetime) AS dom, HOUR(pickup_datetime) AS hr, * FROM taxi_unsorted DISTRIBUTE BY dom """).map_partitions(intra_partition_sort, ['dom', 'hr']).to_parquet('/data/taxi_sorted') .. code-block:: python sorted_ddf = dd.read_parquet( '/data/taxi_sorted', split_row_groups=False, filters=[ [("dom", "==", 15)] ] ) c.create_table("taxi_sorted", sorted_ddf) c.sql("SELECT * FROM taxi_sorted WHERE dom = 15").npartitions .. code-block:: 1 This comes with a large corresponding boost in computation speed. For example, .. code-block:: python %%time c.sql("SELECT COUNT(*) FROM taxi_unsorted WHERE DAYOFMONTH(pickup_datetime) = 15").compute() .. code-block:: CPU times: user 2.4 s, sys: 275 ms, total: 2.68 s Wall time: 2.58 s .. code-block:: python %%time c.sql("SELECT COUNT(*) FROM taxi_sorted WHERE dom = 15").compute() .. code-block:: CPU times: user 318 ms, sys: 21.7 ms, total: 340 ms Wall time: 274 ms For a deeper dive into read filtering with Dask, check out `Filtered Reading with RAPIDS & Dask to Optimize ETL `_. In many cases Dask-SQL can automate sorting and read filtering with its predicate pushdown support. For example, the query .. 
code-block:: sql SELECT COUNT(*) FROM taxi WHERE DAYOFMONTH(pickup_datetime) = 15 would automatically perform the same sorting and read filtering logic as the previous section. Avoid Unnecessary Parallelism ----------------------------- Additionally, more tasks added to the Dask graph means more overhead added by the scheduler which can be a major performance inhibitor at large scales. For CPUs this isn't as much of an issue, as CPUs tend to have allowance for more workers and CPU tasks tend to take longer, so the additional overhead is relatively less impactful. But, for GPUs there's typically only one worker per GPU and tasks tend to be shorter, so the overhead added by a large number of tasks can greatly affect performance. Improve performance by only creating tasks as necessary. For example, splitting row groups creates more tasks so avoid this if possible. .. code-block:: python weather_dir = '/data/weather_pq_2GB/*.parquet' .. code-block:: sql CREATE OR REPLACE TABLE weather_split WITH ( location = '{weather_dir}', gpu=True, split_row_groups=True ) .. code-block:: sql SELECT COUNT(*) FROM weather_split WHERE type='PRCP' .. code-block:: sql CREATE OR REPLACE TABLE weather_nosplit WITH ( location = '{weather_dir}', gpu=True, split_row_groups=False ) .. code-block:: sql SELECT COUNT(*) FROM weather_nosplit WHERE type='PRCP' Use broadcast joins when possible --------------------------------- Joins and grouped aggregations typically require communication between workers, which can be expensive. Broadcast joins can help reduce this communication in the case of joining a small table to a large table by just sending the small table to each partition of the large table. However, in Dask-SQL this only works when the small table is a single partition. For example, if you read in some tables and concatenate them with a ``UNION ALL`` operation .. 
code-block:: sql CREATE OR REPLACE TABLE precip AS SELECT station_id, substring("date", 0, 4) as yr, substring("date", 5, 2) as mth, substring("date", 7, 2) as dy, val*1/10*0.0393701 as inches FROM weather_nosplit WHERE type='PRCP' .. code-block:: sql CREATE OR REPLACE TABLE atlanta_stations WITH ( location = '/data/atlanta_stations/*.parquet', gpu=True ) .. code-block:: sql CREATE OR REPLACE TABLE seattle_stations WITH ( location = '/data/seattle_stations/*.parquet', gpu=True ) .. code-block:: sql CREATE OR REPLACE TABLE city_stations AS SELECT * FROM atlanta_stations UNION ALL SELECT * FROM seattle_stations you get a new table that has two partitions. Then if you use it in a join .. code-block:: sql SELECT yr, city, CASE WHEN city='Atlanta' THEN sum(inches)/{atl_stations} ELSE sum(inches)/{seat_stations} END AS inches FROM precip JOIN city_stations ON precip.station_id = city_stations.station_id GROUP BY yr, city ORDER BY yr ASC Dask-SQL won't perform a broadcast join and will instead perform a traditional join with a corresponding slow compute time. However, if you were to repartition the smaller table to a single partition and rerun the operation .. code-block:: python c.create_table("city_stations", c.sql("select * from city_stations").repartition(npartitions=1)) .. code-block:: sql SELECT yr, city, CASE WHEN city='Atlanta' THEN sum(inches)/{atl_stations} ELSE sum(inches)/{seat_stations} END AS inches FROM precip JOIN city_stations ON precip.station_id = city_stations.station_id GROUP BY yr, city ORDER BY yr ASC Dask-SQL is able to recognize this as a broadcast join and the result is a significantly faster compute time. Dask-SQL also supports biasing the heuristic Dask uses to determine whether to use a broadcast join through the ``sql.join.broadcast`` config option. This option passes either a boolean or a float value to the ``broadcast`` argument in Dask's `merge `_ function. 
In the case of passing a float, a larger value makes Dask more likely to use a broadcast join. For example, .. code-block:: python c.sql(query, config_options={"sql.join.broadcast": True}) would instruct Dask to always use a broadcast join if supported for the query whereas .. code-block:: python c.sql(query, config_options={"sql.join.broadcast": 0.7}) would instruct Dask to use ``0.7`` as the ``broadcast_bias`` in its heuristic for deciding whether to use a broadcast join. Optimize Partition Sizes for GPUs --------------------------------- File formats like `Apache ORC `_ and `Apache Parquet `_ are designed so that they can be pulled from disk and be deserialized by CPUs quickly. However, loading data into GPUs has a substantial additional cost in the form of transfers from CPU to GPU memory. Minimizing that cost is often achieved by increasing partition size. Even when using Dask-SQL on GPUs, upstream CPU systems will likely produce small files resulting in small partitions. It's worth taking the time to repartition to larger partition sizes before querying the files on GPUs, especially when querying the same files multiple times. There's no single optimal size so choose a size that's tuned for your workflow. Operations like joins and concatenations greatly increase GPU memory utilization, even if temporarily, but if you're not performing many of these operations, the larger the partition size the better. Larger partition sizes increase disk to GPU throughput and keep GPU utilization higher for faster runtimes. We recommend a starting point of around 2gb uncompressed data per partition for GPUs. It's usually not necessary to change from default settings when running Dask-SQL on CPUs, but if you want to manually set partition sizes, we've found 128-256mb per partition to be a good starting place. ================================================ FILE: docs/source/cmd.rst ================================================ .. 
_cmd: Command Line Tool ================= It is also possible to run a small CLI tool for testing out some SQL commands quickly. You can either call the CLI tool (after installation) directly .. code-block:: bash dask-sql or by running these lines of code .. code-block:: python from dask_sql import cmd_loop cmd_loop() Some options can be set, e.g. to preload some test data. Have a look at :func:`~dask_sql.cmd_loop` or call .. code-block:: bash dask-sql --help Of course, it is also possible to call the usual ``CREATE TABLE`` commands. Similar to what is described in :ref:`server`, it is possible to preregister your own data sources or choose a dask scheduler to connect to. ================================================ FILE: docs/source/conf.py ================================================ # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # contents of docs/conf.py # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # import os import sys from datetime import datetime # -- Path setup -------------------------------------------------------------- sys.path.insert(0, os.path.abspath("..")) # -- Project information ----------------------------------------------------- project = "dask-sql" copyright = f"{datetime.today().year}, Nils Braun" author = "Nils Braun" # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones.
extensions = [ "sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx_tabs.tabs", "dask_sphinx_theme.ext.dask_config_sphinx_ext", ] # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = "dask_sphinx_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ["_static"] # Make sure to reference the correct master document master_doc = "index" # Do not show type mappings autodoc_typehints = "none" # disable collapsible tabs sphinx_tabs_disable_tab_closing = True ================================================ FILE: docs/source/configuration.rst ================================================ .. _configuration: Configuration in Dask-SQL ========================== ``dask-sql`` supports a list of configuration options to configure behavior of certain operations. ``dask-sql`` uses `Dask's config `_ module and configuration options can be specified with YAML files, via environment variables, or directly, either through the `dask.config.set `_ method or the ``config_options`` argument in the :func:`dask_sql.Context.sql` method. Configuration Reference ----------------------- .. 
dask-config-block:: :location: sql :config: https://raw.githubusercontent.com/dask-contrib/dask-sql/main/dask_sql/sql.yaml :schema: https://raw.githubusercontent.com/dask-contrib/dask-sql/main/dask_sql/sql-schema.yaml ================================================ FILE: docs/source/custom.rst ================================================ .. _custom: Custom Functions and Aggregations ================================= In addition to the built-in SQL functionality, it is possible to use custom functions and aggregations in the SQL queries of ``dask-sql``. The custom functions are classified into scalar functions and aggregations. If you want to combine Machine Learning with SQL, you might also be interested in :ref:`machine_learning`. Scalar Functions ---------------- A scalar function (such as :math:`x \to x^2`) turns a given column into another column of the same length. It can be registered for usage in SQL with the :func:`~dask_sql.Context.register_function` method. Example: .. code-block:: python def f(x): return x ** 2 c.register_function(f, "f", [("x", np.int64)], np.int64) The registration gives a name to the function and also adds type information on the input types and names, as well as the return type. All usual numpy types (e.g. ``np.int64``) and pandas types (``Int64``) are supported. After registration, the function can be used like any other SQL function: .. code-block:: python c.sql("SELECT f(column) FROM data") Scalar functions can have one or more input parameters and can combine columns and literal values. Row-Wise Pandas UDFs -------------------- In some cases it may be easier to write custom functions which process a dict-like row object, such as those consumed by ``pandas.DataFrame.apply``. These functions may be registered as above and flagged as row UDFs using the ``row_udf`` keyword argument: ..
code-block:: python def f(row): return row['a'] + row['b'] c.register_function(f, "f", [("a", np.int64), ("b", np.int64)], np.int64, row_udf=True) c.sql("SELECT f(a, b) FROM data") **Note:** Row UDFs use ``apply``, which may have unpredictable performance characteristics depending on the function and dataframe library. UDFs written in this way can also be extended to accept scalar arguments along with the incoming row: .. code-block:: python def f(row, k): return row['a'] + k c.register_function(f, "f", [("a", np.int64), ("k", np.int64)], np.int64, row_udf=True) c.sql("SELECT f(a, 42) FROM data") Aggregation Functions --------------------- Aggregation functions run on a single column and turn it into a single value. This means they can only be used in ``GROUP BY`` aggregations. They can be registered with the :func:`~dask_sql.Context.register_aggregation` method. This time however, an instance of a :class:`dask.dataframe.Aggregation` needs to be passed instead of a plain function. More information on dask aggregations can be found in the `dask documentation `_. Example: .. code-block:: python my_sum = dd.Aggregation("my_sum", lambda x: x.sum(), lambda x: x.sum()) c.register_aggregation(my_sum, "my_sum", [("x", np.float64)], np.float64) c.sql("SELECT my_sum(other_column) FROM df GROUP BY column") .. note:: There can only ever exist a single function with a given name, no matter whether it is an aggregation function or a scalar function. ================================================ FILE: docs/source/data_input.rst ================================================ .. _data_input: Data Loading and Input ====================== Before data can be queried with ``dask-sql``, it needs to be loaded into the Dask cluster (or local instance) and registered with the :class:`~dask_sql.Context`. ``dask-sql`` supports all ``dask``-compatible `input formats `_, plus some additional formats only suitable for ``dask-sql``. 1.
Load it via Python --------------------- You can either use already created Dask DataFrames or create one by using the :func:`~dask_sql.Context.create_table` function. Chances are high, there exists already a function to load your favorite format or location (e.g. S3 or hdfs). See below for all formats understood by ``dask-sql``. Make sure to install required libraries both on the driver and worker machines: .. tabs:: .. group-tab:: CPU .. code-block:: python import dask.dataframe as dd from dask_sql import Context c = Context() df = dd.read_csv("s3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv") c.create_table("my_data", df) .. group-tab:: GPU .. code-block:: python import dask.dataframe as dd from dask_sql import Context c = Context() df = dd.read_csv("s3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv") c.create_table("my_data", df, gpu=True) or in short (equivalent): .. tabs:: .. group-tab:: CPU .. code-block:: python from dask_sql import Context c = Context() c.create_table("my_data", "s3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv") .. group-tab:: GPU .. code-block:: python from dask_sql import Context c = Context() c.create_table("my_data", "s3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv", gpu=True) 2. Load it via SQL ------------------ If you are connected to the SQL server implementation or you do not want to issue Python command calls, you can also achieve the data loading via SQL only. .. tabs:: .. group-tab:: CPU .. code-block:: sql CREATE TABLE my_data WITH ( format = 'csv', location = 's3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv' ) .. group-tab:: GPU .. code-block:: sql CREATE TABLE my_data WITH ( format = 'csv', location = 's3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv', gpu = True ) The parameters are the same as in the Python function described above. You can find more information in :ref:`creation`. 3. 
Persist and share data on the cluster ---------------------------------------- In ``dask``, you can publish datasets with names into the cluster memory. This allows you to reuse the same data from multiple clients/users in multiple sessions. For example, you can publish your data using the ``client.publish_dataset`` function of the ``distributed.Client``, and then later register it in the :class:`~dask_sql.Context` via SQL: .. code-block:: python # a dask.distributed Client client = Client(...) client.publish_dataset(my_df=df) Later in SQL: .. tabs:: .. group-tab:: CPU .. code-block:: SQL CREATE TABLE my_data WITH ( format = 'memory', location = 'my_df' ) .. group-tab:: GPU .. code-block:: SQL CREATE TABLE my_data WITH ( format = 'memory', location = 'my_df', gpu = True ) Note that the format is set to ``memory`` and the location is the name which was chosen when publishing the dataset. To achieve the same thing from Python, you can just use Dask's methods to get the dataset .. tabs:: .. group-tab:: CPU .. code-block:: python df = client.get_dataset("my_df") c.create_table("my_data", df) .. group-tab:: GPU .. code-block:: python df = client.get_dataset("my_df") c.create_table("my_data", df, gpu=True) Input Formats ------------- ``dask-sql`` understands (thanks to the large Dask ecosystem) a wide variety of input formats and input locations. * All formats and locations mentioned in `the Dask documentation `_, including CSV, Parquet, and JSON. Just pass in the location as string (and possibly the format, e.g. "csv" if it is not clear from the file extension). The data can be from local disk or many remote locations (S3, hdfs, Azure Filesystem, http, Google Filesystem, ...) - just prefix the path with the matching protocol. Additional arguments passed to :func:`~dask_sql.Context.create_table` or ``CREATE TABLE`` are given to the ``read_`` calls. Example: .. tabs:: .. group-tab:: CPU ..
code-block:: python c.create_table( "my_data", "s3://bucket-name/my-data-*.csv", storage_options={'anon': True} ) .. code-block:: sql CREATE TABLE my_data WITH ( format = 'csv', -- can also be omitted, as clear from the extension location = 's3://bucket-name/my-data-*.csv', storage_options = ( anon = True ) ) .. group-tab:: GPU .. code-block:: python c.create_table( "my_data", "s3://bucket-name/my-data-*.csv", gpu=True, storage_options={'anon': True} ) .. code-block:: sql CREATE TABLE my_data WITH ( format = 'csv', -- can also be omitted, as clear from the extension location = 's3://bucket-name/my-data-*.csv', gpu = True, storage_options = ( anon = True ) ) * If your data is already in Pandas (or Dask) DataFrames format, you can just use it as it is via the Python API by giving it to :func:`~dask_sql.Context.create_table` directly. * You can connect ``dask-sql`` to an `intake `_ catalog and use the data registered there. Assuming you have an intake catalog stored in "catalog.yaml" (can also be the URL of an intake server), you can read in a stored table "data_table" either via Python .. code-block:: python catalog = intake.open_catalog("catalog.yaml") c.create_table("my_data", catalog, intake_table_name="intake_table") # or c.create_table("my_data", "catalog.yaml", format="intake", intake_table_name="intake_table") or via SQL: .. code-block:: sql CREATE TABLE my_data WITH ( format = 'intake', location = 'catalog.yaml' ) The argument ``intake_table_name`` is optional and defaults to the table name in ``dask_sql``. With the argument ``catalog_kwargs`` you can control how the intake catalog object is created. Additional arguments are forwarded to the ``to_dask()`` call of intake. * As an experimental feature, it is also possible to use data stored in the `Apache Hive `_ metastore. For this, ``dask-sql`` will retrieve the information on the storage location and format from the metastore and will then register the raw data directly in the context. 
This means, no Hive data query will be issued and you might be able to see a speed improvement. It is both possible to use a `pyhive.hive.Cursor` or an `sqlalchemy` connection. .. code-block:: python from dask_sql import Context from pyhive.hive import connect import sqlalchemy c = Context() cursor = connect("hive-server", 10000).cursor() # or cursor = sqlalchemy.create_engine("hive://hive-server:10000").connect() c.create_table("my_data", cursor, hive_table_name="the_name_in_hive") or in SQL: .. code-block:: sql CREATE TABLE my_data WITH ( location = 'hive://hive-server:10000', hive_table_name = 'the_name_in_hive' ) Again, ``hive_table_name`` is optional and defaults to the table name in ``dask-sql``. You can also control the database used in Hive via the ``hive_schema_name`` parameter. Additional arguments are pushed to the internally called ``read_`` functions. * Similarly, it is possible to load data from a `Databricks Cluster `_ (which is similar to a Hive metastore). You need to have the ``databricks-dbapi`` package installed and ``fsspec >= 0.8.7``. A token needs to be `generated `_ for the accessing user. The ``host``, ``port`` and ``http_path`` information can be found in the JDBC tab of the cluster. .. code-block:: python from dask_sql import Context from sqlalchemy import create_engine c = Context() cursor = create_engine(f"databricks+pyhive://token:{token}@{host}:{port}/", connect_args={"http_path": http_path}).connect() c.create_table("my_data", cursor, hive_table_name="schema.table", storage_options={"instance": host, "token": token}) or in SQL .. code-block:: sql CREATE TABLE my_data WITH ( location = 'databricks+pyhive://token:{token}@{host}:{port}/', connect_args = ( http_path = '{http_path}' ), hive_table_name = 'schema.table', storage_options = ( instance = '{host}', token = '{token}' ) ) .. note:: For ``dask-sql`` it does not matter how you load your data. 
In all shown cases you can then use the specified table name to query your data in a ``SELECT`` call. Please note however that un-persisted data will be reread from its source (e.g. on S3 or disk) on every query whereas persisted data is only read once. This will increase the query speed, but will also prevent you from seeing external updates to your data (until you reload it explicitly). ================================================ FILE: docs/source/fugue.rst ================================================ FugueSQL Integrations ===================== `FugueSQL `_ is a related project that aims to provide a unified SQL interface for a variety of different computing frameworks, including Dask. While it offers a SQL engine with a larger set of supported commands, this comes at the cost of slower performance when using Dask in comparison to dask-sql. In order to offer a "best of both worlds" solution, dask-sql includes several options to integrate with FugueSQL, using its faster implementation of SQL commands when possible and falling back on FugueSQL when necessary. dask-sql as a FugueSQL engine ----------------------------- FugueSQL users unfamiliar with dask-sql can take advantage of its functionality by installing it in an environment alongside Fugue; this will automatically register :class:`dask_sql.integrations.fugue.DaskSQLExecutionEngine` as the default Dask execution engine for FugueSQL queries. For more information and sample usage, see `Fugue — dask-sql as a FugueSQL engine `_. Using FugueSQL on an existing ``Context`` ----------------------------------------- dask-sql users attempting to expand their SQL querying options for an existing ``Context`` can use :func:`dask_sql.integrations.fugue.fsql_dask`, which executes the provided query using FugueSQL, using the tables within the provided context as input. The results of this query can then optionally be registered to the context: .. 
code-block:: python # define a custom prepartition function for FugueSQL def median(df: pd.DataFrame) -> pd.DataFrame: df["y"] = df["y"].median() return df.head(1) # create a context with some tables c = Context() ... # run a FugueSQL query using the context as input query = """ j = SELECT df1.*, df2.x FROM df1 INNER JOIN df2 ON df1.key = df2.key PERSIST TAKE 5 ROWS PREPARTITION BY x PRESORT key PRINT TRANSFORM j PREPARTITION BY x USING median PRINT """ result = fsql_dask(query, c, register=True) # results aren't registered by default assert "j" in result # returns a dict of resulting tables assert "j" in c.tables # results are also registered to the context ================================================ FILE: docs/source/how_does_it_work.rst ================================================ How does it work? ================= At the core, ``dask-sql`` does two things: - Translates the SQL query using `Apache Arrow DataFusion `_ into a relational algebra, represented by a `LogicalPlan enum `_ - similar to many other SQL engines (Hive, Flink, ...) - Converts this description of the query from the Rust enum into Dask API calls (and executes them) - returning a Dask dataframe. The following example explains this in some technical detail. For most users, this level of technical understanding is not needed. 1. SQL enters the library ------------------------- No matter whether it comes via the Python API (:ref:`api`), the command line client (:ref:`cmd`) or the server (:ref:`server`), the user's SQL statement eventually ends up as a string in the function :func:`~dask_sql.Context.sql`. 2. SQL is parsed ---------------- This function will first give the SQL string to the dask_planner Rust crate via the ``PyO3`` library. Inside this crate, Apache Arrow DataFusion is used to first parse the SQL string and then turn it into a relational algebra.
For this, DataFusion uses the SQL language description specified in the `sqlparser-rs library `_ We also include `SQL extensions specific to Dask-SQL `_. They specify custom language features, such as the ``CREATE MODEL`` statement. 3. SQL is (maybe) optimized --------------------------- Once the SQL string is parsed into a :class:`Statement` enum, DataFusion can convert it into a relational algebra represented by a `LogicalPlan enum `_ and optimize it. As this is only implemented for DataFusion supported syntax (and not for the custom syntax such as :class:`SqlCreateModel`) this conversion and optimization is not triggered for all SQL statements (have a look into :func:`Context._get_ral`). The logical plan is a tree structure and most enum variants (such as :class:`Projection` or :class:`Join`) can contain other instances as "inputs" creating a tree of different steps in the SQL statement (see below for an example). The result is an optimized :class:`LogicalPlan`. 4. Translation to Dask API calls -------------------------------- Each step in the :class:`LogicalPlan` is converted into calls to Python functions using different Python "converters". For each enum variant (such as :class:`Projection` and :class:`Join`), there exist a converter class in the ``dask_sql.physical.rel`` folder, which are registered at the :class:`dask_sql.physical.rel.convert.RelConverter` class. Their job is to use the information stored in the logical plan enum variants and turn it into calls to Python functions (see the example below for more information). As many SQL statements contain calculations using literals and/or columns, these are split into their own functionality (``dask_sql.physical.rex``) following a similar plugin-based converter system. Have a look into the specific classes to understand how the conversion of a specific SQL language feature is implemented. 5. Result --------- The result of each of the conversions is a :class:`dask.DataFrame`, which is given to the user. 
In the case of the command line tool or the SQL server, it is evaluated immediately - otherwise it can be used for further calculations by the user. Example ------- Let's walk through the steps above using the example SQL statement .. code-block:: sql SELECT x + y FROM timeseries WHERE x > 0 assuming the table "timeseries" is already registered. If you want to follow along with the steps outlined below, start the command line tool in debug mode .. code-block:: bash dask-sql --load-test-data --startup --log-level DEBUG and enter the SQL statement above. First, the SQL is parsed by DataFusion and (as it is not a custom statement) transformed into a tree of relational algebra objects. .. code-block:: none Projection: #timeseries.x + #timeseries.y Filter: #timeseries.x > Float64(0) TableScan: timeseries projection=[x, y] The tree output above means that the outer instance (:class:`Projection`) needs as input the output of the previous instance (:class:`Filter`) etc. Therefore the conversion to Python API calls is called recursively (depth-first). First, the :class:`LogicalTableScan` is converted using the :class:`rel.logical.table_scan.LogicalTableScanPlugin` plugin. It will just get the correct :class:`dask.DataFrame` from the dictionary of already registered tables of the context. Next, the :class:`LogicalFilter` (having the dataframe as input) is converted via the :class:`rel.logical.filter.LogicalFilterPlugin`. The filter expression ``>($3, 0)`` is converted into ``df["x"] > 0`` using a combination of REX plugins (have a look at the debug output to learn more) and applied to the dataframe. The resulting dataframe is then passed to the converter :class:`rel.logical.project.LogicalProjectPlugin` for the :class:`LogicalProject`. This will calculate the expression ``df["x"] + df["y"]`` (after having converted it via the :class:`RexCallPlugin` plugin) and return the final result to the user. ..
code-block:: python df_table_scan = context.tables["timeseries"] df_filter = df_table_scan[df_table_scan["x"] > 0] df_project = df_filter.assign(col=df_filter["x"] + df_filter["y"]) return df_project[["col"]] ================================================ FILE: docs/source/index.rst ================================================ dask-sql ======== ``dask-sql`` is a distributed SQL query engine in Python. It allows you to query and transform your data using a mixture of common SQL operations and Python code and also scale up the calculation easily if you need it. * **Combine the power of Python and SQL**: load your data with Python, transform it with SQL, enhance it with Python and query it with SQL - or the other way round. With ``dask-sql`` you can mix the well known Python dataframe API of `pandas` and ``Dask`` with common SQL operations, to process your data in exactly the way that is easiest for you. * **Infinite Scaling**: using the power of the great ``Dask`` ecosystem, your computations can scale as you need it - from your laptop to your super cluster - without changing any line of SQL code. From k8s to cloud deployments, from batch systems to YARN - if ``Dask`` `supports it `_, so will ``dask-sql``. * **Your data - your queries**: Use Python user-defined functions (UDFs) in SQL without any performance drawback and extend your SQL queries with the large number of Python libraries, e.g. machine learning, different complicated input formats, complex statistics. * **Easy to install and maintain**: ``dask-sql`` is just a pip/conda install away (or a docker run if you prefer). * **Use SQL from wherever you like**: ``dask-sql`` integrates with your jupyter notebook, your normal Python module or can be used as a standalone SQL server from any BI tool. It even integrates natively with `Apache Hue `_. 
* **GPU Support**: ``dask-sql`` has support for running SQL queries on CUDA-enabled GPUs by utilizing `RAPIDS `_ libraries like `cuDF `_ , enabling accelerated compute for SQL. Example ------- For this example, we use some data loaded from disk and query it with a SQL command. ``dask-sql`` accepts any pandas, cuDF, or dask dataframe as input and is able to read data directly from a variety of storage formats (CSV, Parquet, JSON) and file systems (S3, hdfs, gcs): .. tabs:: .. group-tab:: CPU .. code-block:: python import dask.datasets from dask_sql import Context # create a context to register tables c = Context() # create a table and register it in the context df = dask.datasets.timeseries() c.create_table("timeseries", df) # execute a SQL query; the result is a "lazy" Dask dataframe result = c.sql(""" SELECT name, SUM(x) as "sum" FROM timeseries GROUP BY name """) # actually compute the query... result.compute() # ...or use it for another computation result["sum"].mean().compute() .. group-tab:: GPU .. code-block:: python import dask.datasets from dask_sql import Context # create a context to register tables c = Context() # create a table and register it in the context df = dask.datasets.timeseries() c.create_table("timeseries", df, gpu=True) # execute a SQL query; the result is a "lazy" Dask dataframe result = c.sql(""" SELECT name, SUM(x) as "sum" FROM timeseries GROUP BY name """) # actually compute the query... result.compute() # ...or use it for another computation result["sum"].mean().compute() .. toctree:: :maxdepth: 1 :caption: Contents: installation quickstart sql data_input custom machine_learning best_practices api server cmd fugue how_does_it_work configuration .. note:: ``dask-sql`` is currently under development and does so far not understand all SQL commands. We are actively looking for feedback, improvements and contributors! 
================================================ FILE: docs/source/installation.rst ================================================ .. _installation: Installation ============ ``dask-sql`` can be installed via ``conda`` (preferred) or ``pip`` - or in a development environment. You can continue with the :ref:`quickstart` after the installation. With ``conda`` -------------- Create a new conda environment or use your already present environment: .. code-block:: bash conda create -n dask-sql conda activate dask-sql Install the package from the ``conda-forge`` channel: .. code-block:: bash conda install dask-sql -c conda-forge GPU support ^^^^^^^^^^^ - GPU support is currently tied to the `RAPIDS `_ libraries. - It generally requires the latest `cuDF/Dask-cuDF `_ nightlies. Create a new conda environment or use an existing one to install RAPIDS with the chosen methods and packages. More details can be found on the `RAPIDS Getting Started `_ page, but as an example: .. code-block:: bash conda create --name rapids-env -c rapidsai-nightly -c nvidia -c conda-forge \ cudf=22.10 dask-cudf=22.10 ucx-py ucx-proc=*=gpu python=3.9 cudatoolkit=11.8 conda activate rapids-env Note that using UCX is mainly necessary if you have an Infiniband or NVLink enabled system. Refer to the `UCX-Py docs `_ for more information. Install the stable package from the ``conda-forge`` channel: .. code-block:: bash conda install -c conda-forge dask-sql Or the latest nightly from the ``dask`` channel (currently only available for Linux-based operating systems): .. code-block:: bash conda install -c dask/label/dev dask-sql With ``pip`` ------------ .. code-block:: bash pip install dask-sql For development --------------- If you want to have the newest (unreleased) ``dask-sql`` version or if you plan to do development on ``dask-sql``, you can also install the package from sources. .. 
code-block:: bash git clone https://github.com/dask-contrib/dask-sql.git Create a new conda environment and install the development environment: .. code-block:: bash conda env create -f continuous_integration/environment-3.9.yaml It is not recommended to use ``pip`` instead of ``conda``. After that, you can install the package in development mode .. code-block:: bash pip install -e ".[dev]" To compile the Rust code (after changes), the above command must be rerun. You can run the tests (after installation) with .. code-block:: bash pytest tests GPU-specific tests require additional dependencies specified in `continuous_integration/gpuci/environment.yaml`: .. code-block:: bash conda env create -n dask-sql-gpuci -f continuous_integration/gpuci/environment.yaml GPU-specific tests can be run with .. code-block:: bash pytest tests -m gpu --rungpu This repository uses pre-commit hooks. To install them, call .. code-block:: bash pre-commit install ================================================ FILE: docs/source/machine_learning.rst ================================================ .. _machine_learning: Machine Learning ================ .. note:: Machine Learning support is experimental in ``dask-sql``. We encourage you to try it out and report any issues on our `issue tracker `_. Both the training as well as the prediction using Machine Learning methods play a crucial role in many data analytics applications. ``dask-sql`` supports Machine Learning applications in different ways, depending on how much you would like to do in Python or SQL. Please also see :ref:`ml` for more information on the SQL statements used on this page. 1. Data Preparation in SQL, Training and Prediction in Python ------------------------------------------------------------- If you are familiar with Python and the ML ecosystem in Python, this one is probably the simplest possibility. 
You can use the :func:`~dask_sql.Context.sql` call as described before to extract the data for your training or ML prediction. The result will be a Dask dataframe, which you can either directly feed into your model or convert to a pandas dataframe with `.compute()` before. This gives you full control on the training process and the simplicity of using SQL for data manipulation. You can use this method in your Python scripts or Jupyter Notebooks, but not from the :ref:`server` or :ref:`cmd`. 2. Training in Python, Prediction in SQL ---------------------------------------- In many companies/teams, it is typical that some team members are responsible for creating/training a ML model, and others use it to predict unseen data. It would be possible to create a custom function (see :ref:`custom`) to load and use the model, which then can be used in ``SELECT`` queries. However for convenience, ``dask-sql`` introduces a SQL keyword to do this work for you automatically. The syntax is similar to the `BigQuery Predict Syntax `_. .. code-block:: python c.sql(""" SELECT * FROM PREDICT ( MODEL my_model, SELECT x, y, z FROM data ) """) This call will first collect the data from the inner ``SELECT`` call (which can be any valid ``SELECT`` call, including ``JOIN``, ``WHERE``, ``GROUP BY``, custom tables and views etc.) and will then apply the model with the name "my_model" for prediction. The model needs to be registered at the context before using :func:`~dask_sql.Context.register_model`. .. code-block:: python c.register_model("my_model", model) The model registered here can be any valid python object, which follows the scikit-learn interface, which is to have a ``predict()`` function. Please note that the input will not be pandas dataframe, but a Dask dataframe. See :ref:`ml` for more information. 3. 
Training and Prediction in SQL --------------------------------- This method, in contrast to the other two possibilities, works completely from SQL, which allows you to also call it e.g. from your BI tool. In addition to the ``PREDICT`` keyword mentioned above, ``dask-sql`` also has a way to create and train a model from SQL: .. code-block:: sql CREATE MODEL my_model WITH ( model_class = 'LogisticRegression', wrap_predict = True, target_column = 'target' ) AS ( SELECT x, y, x*y > 0 as target FROM timeseries LIMIT 100 ) This call will create a new instance of ``sklearn.linear_model.LogisticRegression`` or ``cuml.linear_model.LogisticRegression`` (the full path is inferred by Dask-SQL depending on whether you are using a CPU or GPU DataFrame) and train it with the data collected from the ``SELECT`` call (again, every valid ``SELECT`` query can be given). The model can then be used in subsequent calls to ``PREDICT`` using the given name. We explicitly set ``wrap_predict`` = ``True`` here to parallelize the post-fit prediction task of non-distributed models (sklearn/cuML etc.) across workers, although in this case ``wrap_predict`` would have already defaulted to ``True`` for the sklearn model. Have a look into :ref:`ml` for more information. 4. Check Model parameters - Model meta data ------------------------------------------- After the model was trained, you can inspect and get model details by using the following SQL statements .. code-block:: sql -- show the list of models which are trained and stored in the context. SHOW MODELS -- To get the hyperparameters of the trained MODEL, use -- DESCRIBE MODEL . DESCRIBE MODEL my_model 5. Hyperparameter Tuning ------------------------- Want to increase the performance of your model by tuning the parameters? Use the hyperparameter tuning directly in SQL using the SQL syntax below, and choose different tuners based on memory and compute constraints. ..
TODO - add a GPU section to these examples once we have working CREATE EXPERIMENT tests for GPU .. code-block:: sql CREATE EXPERIMENT my_exp WITH ( model_class = 'GradientBoostingClassifier', experiment_class = 'GridSearchCV', tune_parameters = (n_estimators = ARRAY [16, 32, 2], learning_rate = ARRAY [0.1,0.01,0.001], max_depth = ARRAY [3,4,5,10] ), experiment_kwargs = (n_jobs = -1), target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) In this case, we set ``n_jobs`` = ``-1`` to ensure that all jobs run in parallel. 5.1 AutoML in SQL ----------------- Want to try different models with different parameters in SQL? Now you can start AutoML experiments with the help of the ``tpot`` framework, which trains and evaluates a number of different sklearn-compatible models and uses Dask for distributing the work across the Dask clusters. Use the SQL syntax below for AutoML; for more details, refer to the `tpot automl framework `_ .. code-block:: sql CREATE EXPERIMENT my_exp WITH ( automl_class = 'tpot.TPOTClassifier', automl_kwargs = (population_size = 2 , generations=2, cv=2, n_jobs=-1, use_dask=True, max_eval_time_mins=1), target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) After the experiment has completed, both hyperparameter tuner and AutoML experiments store the best model of the experiment in the SQL context under the same name as the experiment, which can be used for prediction. 6. Export Trained Model ------------------------ Once your model is trained and performs well on your validation dataset, you can export the model into a file with one of the supported model serialization formats like Pickle, Joblib, MLflow (framework-agnostic serialization format), etc. Currently, Dask-SQL supports the Pickle, Joblib and MLflow formats for exporting the trained model, which can then be deployed as microservices, etc.
Before training and exporting the models from different frameworks like LightGBM or CatBoost, please ensure the relevant packages are installed in the Dask-SQL environment, otherwise it will raise an exception on import. If you are using MLflow, ensure MLflow is installed. Keep in mind that Dask-SQL supports only sklearn-compatible models (i.e. fit-predict style models) so far, so instead of using ``xgb.core.Booster``, consider using ``xgboost.XGBClassifier`` since the latter is sklearn-compatible and used by Dask-SQL for training, predicting, and exporting the model through the standard sklearn interface. .. TODO - add a GPU section to these examples once we have working EXPORT MODEL tests for GPU .. code-block:: sql -- for pickle model serialization EXPORT MODEL my_model WITH ( format ='pickle', location = 'model.pkl' ) -- for joblib model serialization EXPORT MODEL my_model WITH ( format ='joblib', location = 'model.pkl' ) -- for mlflow model serialization EXPORT MODEL my_model WITH ( format ='mlflow', location = 'mlflow_dir' ) -- Note that you can pass additional key-value pairs -- (parameters), which will be delegated to the respective -- export functions Example ~~~~~~~ The following SQL-only code gives an example of how the commands can play together. We assume that you have created/registered a table "my_data" with the numerical columns ``x`` and ``y`` and the boolean target ``label``. .. TODO - add a GPU section to these examples once we have working CREATE EXPERIMENT tests for GPU .. code-block:: sql -- First, we create a new feature z out of x and y. -- For convenience, we store it in another table CREATE OR REPLACE TABLE transformed_data AS ( SELECT x, y, x + y AS z, label FROM my_data ) -- We split the data into a training set -- by using the first 15 items.
-- Please note that this is just for a very quick-and-dirty -- example - you would probably want to do something -- more advanced here, maybe with TABLESAMPLE CREATE OR REPLACE TABLE training_data AS ( SELECT * FROM transformed_data LIMIT 15 ) -- Quickly check the data SELECT * FROM training_data -- We can now train a model from the sklearn package. CREATE OR REPLACE MODEL my_model WITH ( model_class = 'sklearn.ensemble.GradientBoostingClassifier', wrap_predict = True, target_column = 'label' ) AS ( SELECT * FROM training_data ) -- Now apply the trained model on all the data -- and compare. SELECT *, (CASE WHEN target = label THEN True ELSE False END) AS correct FROM PREDICT(MODEL my_model, SELECT * FROM transformed_data ) -- list models SHOW MODELS -- check parameters of the model DESCRIBE MODEL my_model -- experiment to tune different hyperparameters CREATE EXPERIMENT my_exp WITH( model_class = 'sklearn.ensemble.GradientBoostingClassifier', experiment_class = 'sklearn.model_selection.GridSearchCV', tune_parameters = (n_estimators = ARRAY [16, 32, 2], learning_rate = ARRAY [0.1,0.01,0.001], max_depth = ARRAY [3,4,5,10] ), experiment_kwargs = (n_jobs = -1), target_column = 'label' ) AS ( SELECT * FROM training_data ) -- creates experiment with automl framework CREATE EXPERIMENT my_exp WITH ( automl_class = 'tpot.TPOTRegressor', automl_kwargs = (population_size = 2 , generations=2, cv=2, n_jobs=-1, use_dask=True, max_eval_time_mins=1), target_column = 'z' ) AS ( SELECT * FROM training_data ) -- checks the parameter of automl model DESCRIBE MODEL automl_TPOTRegressor -- export model EXPORT MODEL my_model WITH ( format ='pickle', location = 'model.pkl' ) ================================================ FILE: docs/source/quickstart.rst ================================================ .. _quickstart: Quickstart ========== After :ref:`installation`, you can start querying your data using SQL. 
Run the following code in an interactive Python session, a Python script or a Jupyter Notebook. 0. Cluster Setup ---------------- If you just want to try out ``dask-sql`` quickly, this step can be skipped. However, the real magic of ``dask`` (and ``dask-sql``) comes from the ability to scale the computations over multiple cores and/or machines. For local development and testing, a Distributed ``LocalCluster`` (or, if using GPUs, a `Dask-CUDA `_ ``LocalCUDACluster``) can be deployed and a client connected to it like so: .. tabs:: .. group-tab:: CPU .. code-block:: python from distributed import Client, LocalCluster cluster = LocalCluster() client = Client(cluster) .. group-tab:: GPU .. code-block:: python from dask_cuda import LocalCUDACluster from distributed import Client cluster = LocalCUDACluster() client = Client(cluster) There are several options for deploying clusters depending on the platform being used and the resources available; see `Dask - Deploying Clusters `_ for more information. 1. Data Loading --------------- Before querying the data, you need to create a ``dask`` `data frame `_ containing the data. ``dask`` understands many different `input formats `_ and sources. In this example, we do not read in external data, but use test data in the form of random event time series: .. code-block:: python import dask.datasets df = dask.datasets.timeseries() Read more on the data input part in :ref:`data_input`. 2. Data Registration -------------------- If we want to work with the data in SQL, we need to give the data frame a unique name. We do this by registering the data in an instance of a :class:`~dask_sql.Context`: .. tabs:: .. group-tab:: CPU .. code-block:: python from dask_sql import Context c = Context() c.create_table("timeseries", df) .. group-tab:: GPU .. code-block:: python from dask_sql import Context c = Context() c.create_table("timeseries", df, gpu=True) From now on, the data is accessible as the ``timeseries`` table of this context. 
It is possible to register multiple data frames in the same context. .. hint:: If you plan to query the same data multiple times, it might make sense to persist the data before: .. tabs:: .. group-tab:: CPU .. code-block:: python c.create_table("timeseries", df, persist=True) .. group-tab:: GPU .. code-block:: python c.create_table("timeseries", df, persist=True, gpu=True) 3. Run your queries ------------------- Now you can go ahead and query the data with normal SQL! .. code-block:: python result = c.sql(""" SELECT name, SUM(x) AS "sum" FROM timeseries WHERE x > 0.5 GROUP BY name """) result.compute() ``dask-sql`` understands a large fraction of SQL commands, but there are still some missing. Have a look into the :ref:`sql` description for more information. If you are using ``dask-sql`` from a Jupyter notebook, you might be interested in the ``sql`` magic function: .. code-block:: python c.ipython_magic() %%sql SELECT name, SUM(x) AS "sum" FROM timeseries WHERE x > 0.5 GROUP BY name .. note:: If you have found an SQL feature, which is currently not supported by ``dask-sql``, please raise an issue on our `issue tracker `_. ================================================ FILE: docs/source/server.rst ================================================ .. _server: SQL Server ========== ``dask-sql`` comes with a small test implementation for a SQL server. Instead of rebuilding a full ODBC driver, we re-use the `presto wire protocol `_. .. note:: It is - so far - only a start of the development and missing important concepts, such as authentication. You can test the sql presto server by running (after installation) .. code-block:: bash dask-sql-server or by running these lines of code .. code-block:: python from dask_sql import run_server run_server() or directly with a created context .. code-block:: python c.run_server() or by using the created docker image .. 
code-block:: bash docker run --rm -it -p 8080:8080 nbraun/dask-sql This will spin up a server on port 8080 (by default). The port and bind interfaces can be controlled with the ``--port`` and ``--host`` command line arguments (or options to :func:`~dask_sql.run_server`). The running server looks similar to a normal presto database to any presto client and can therefore be used with any library, e.g. the `presto CLI client `_ or ``sqlalchemy`` via the `PyHive `_ package: .. code-block:: bash presto --server localhost:8080 Now you can fire simple SQL queries (as no data is loaded by default): .. code-block:: => SELECT 1 + 1; EXPR$0 -------- 2 (1 row) Or via ``sqlalchemy`` (after having installed ``PyHive``): .. code-block:: python from sqlalchemy.engine import create_engine engine = create_engine('presto://localhost:8080/') import pandas as pd pd.read_sql_query("SELECT 1 + 1", con=engine) Of course, it is also possible to call the usual ``CREATE TABLE`` commands. Preregister your own data sources --------------------------------- The python function :func:`~dask_sql.run_server` accepts an already created :class:`~dask_sql.Context`. This means you can preload your data sources and register them with a context before starting your server. By this, your server will already have data to query: .. code-block:: python from dask_sql import Context c = Context() c.create_table(...) # Then spin up the ``dask-sql`` server from dask_sql import run_server run_server(context=c) Run it in your own ``dask`` cluster ----------------------------------- The SQL server implementation in ``dask-sql`` allows you to run a SQL server as a service connected to your ``dask`` cluster. This enables your users to run SQL command leveraging the full power of your ``dask`` cluster without the need to write python code and allows also the usage of different non-python tools (such as BI tools) as long as they can speak the presto protocol. 
To run a standalone SQL server in your ``dask`` cluster, follow these three steps: 1. Create a startup script to connect ``dask-sql`` to your cluster. There exist many different ways to connect to a ``dask`` cluster (e.g. direct access to the scheduler, dask gateway, ...). Choose the one suitable for your cluster and create a small startup script: .. code-block:: python # Connect to your cluster here, e.g. from dask.distributed import Client client = Client(scheduler_address) ... # Then spin up the ``dask-sql`` server from dask_sql import run_server run_server(client=client) 2. Deploy this script to your cluster as a service. How you do this, depends on your cluster infrastructure (kubernetes, mesos, openshift, ...). For example you could create a docker image with a dockerfile similar to this: .. code-block:: dockerfile FROM nbraun/dask-sql COPY continuous_integration/docker/startup_script.py /opt/dask_sql/startup_script.py ENTRYPOINT [ "/opt/conda/bin/python", "/opt/dask_sql/startup_script.py" ] 3. After your service is deployed, you can use it in your applications as a "normal" presto database. The ``dask-sql`` SQL server was successfully tested with `Apache Hue `_, `Apache Superset `_ and `Metabase `_. Running from a jupyter notebook ------------------------------- If you quickly want to bridge the gap between your jupyter notebook and a BI tool, you can run a temporary SQL server from your jupyter notebook. .. code-block:: python # Create a Context and work with it from dask_sql import Context c = Context() ... # Later create a temporary server c.run_server(blocking=False) # Continue working This allows you to access the same context with all its registered tables both in the jupyter notebook as well as by connecting to the SQL server started on port 8080 (e.g. with your BI tool). Once you are done with the SQL server, you can close it with .. 
code-block:: python c.stop_server() Please note that this feature should not be used for productive SQL servers, but just for quick analyses via an external application. ================================================ FILE: docs/source/sql/creation.rst ================================================ .. _creation: Table Creation ============== As described in :ref:`quickstart`, it is possible to register an already created dask dataframe with a call to ``c.create_table``. However, it is also possible to load data directly from disk (or s3, hdfs, URL, hive, ...) and register it as a table in ``dask_sql``. Behind the scenes, a call to one of the ``read_`` of the ``dask.dataframe`` will be executed. Additionally, queries can be materialized into new tables for caching or faster access. .. raw:: html
    CREATE [ OR REPLACE ] TABLE [ IF NOT EXISTS ] <table-name>
        WITH ( <key> = <value> [ , ... ] )
    CREATE [ OR REPLACE ] TABLE [ IF NOT EXISTS ] <table-name>
        AS ( SELECT ... )
    CREATE [ OR REPLACE ] VIEW [ IF NOT EXISTS ] <table-name>
        AS ( SELECT ... )
    DROP TABLE | VIEW [ IF EXISTS ] <table-name>
    
See :ref:`sql` for information on how to reference tables correctly. Please note, that there can only ever exist a single view or table with the same name. .. note:: As there is only a single schema "schema" in ``dask-sql``, table names should not include a separator "." in ``CREATE`` calls. By default, if a table with the same name does already exist, ``dask-sql`` will raise an exception (and in turn will raise an exception if you try to delete a table which is not present). With the flags ``IF [NOT] EXISTS`` and ``OR REPLACE``, this behavior can be controlled: * ``CREATE OR REPLACE TABLE | VIEW`` will override an already present table/view with the same name without raising an exception. * ``CREATE TABLE IF NOT EXISTS`` will not create the table/view if it already exists (and will also not raise an exception). * ``DROP TABLE | VIEW IF EXISTS`` will only drop the table/view if it exists and will not do anything otherwise. ``CREATE TABLE WITH`` --------------------- This will create and register a new table "df" with the data under the specified location and format. For information on how to specify key-value arguments properly, see :ref:`sql`. With the ``persist`` parameter, it can be controlled if the data should be cached or re-read for every SQL query. The additional parameters are passed to the particular data loading functions. If you omit the format argument, it will be deduced from the file name extension. More ways to load data can be found in :ref:`data_input`. Example: .. raw:: html
CREATE TABLE df WITH (
        location = "/some/file/path",
        format = "csv/parquet/json/...",
        persist = True,
        additional_parameter = value,
        ...
    )
    
``CREATE TABLE AS`` ------------------- Using a similar syntax, it is also possible to create a (materialized) view of a (maybe complicated) SQL query. With the command, you give the result of the ``SELECT`` query a name that you can use in subsequent calls. The ``SELECT`` can also contain a call to ``PREDICT``, see :ref:`ml`. Example: .. code-block:: sql CREATE TABLE my_table AS ( SELECT a, b, SUM(c) FROM data GROUP BY a, b ... ) SELECT * FROM my_table ``CREATE VIEW AS`` ------------------ Instead of using ``CREATE TABLE`` it is also possible to use ``CREATE VIEW``. The result is very similar; the only difference is *when* the result will be computed: a view is recomputed on every usage, whereas a table is only calculated once on creation (also known as a materialized view). This means, if you e.g. read data from a remote file and the file changes, a query containing a view will be updated whereas a query with a table will stay as it is. To update a table, you need to recreate it. .. hint:: Use views to simplify complicated queries (like a "shortcut") and tables for caching. .. note:: The update of the view only works if your primary data sources (the files you were reading in) are not persisted during reading. Example: .. code-block:: sql CREATE VIEW my_table AS ( SELECT a, b, SUM(c) FROM data GROUP BY a, b ... ) SELECT * FROM my_table ``DROP TABLE | VIEW`` --------------------- Remove a table or view with the given name. Please note again that views and tables are treated equally, so ``DROP TABLE`` will also delete the view with the given name and vice versa. ================================================ FILE: docs/source/sql/describe.rst ================================================ Metadata Information ==================== With these operations, it is possible to get information on the currently registered tables and their columns. The output format is mostly compatible with the presto format. .. raw:: html
    SHOW SCHEMAS
    SHOW TABLES FROM <schema-name>
    SHOW COLUMNS FROM <table-name>
    DESCRIBE <table-name>
    ANALYZE TABLE <table-name> COMPUTE STATISTICS
        [ FOR ALL COLUMNS | FOR COLUMNS <column>, [ ,... ] ]
    
See :ref:`sql` for information on how to reference schemas and tables correctly. ``SHOW SCHEMAS`` ---------------- Show the schemas registered in ``dask-sql``. Only included for compatibility reasons. There is always just one schema called "schema", where all the data is located, and an additional schema called "information_schema", which is needed by some BI tools (and is empty). Example: .. raw:: html
    SHOW SCHEMAS
    
Result: +------------------------+ | Schema | +========================+ | schema | +------------------------+ | information_schema | +------------------------+ ``SHOW TABLES`` --------------- Show the registered tables in a given schema. Example: .. raw:: html
    SHOW TABLES FROM "schema"
    
Result: +------------+ | Table | +============+ | timeseries | +------------+ ``SHOW COLUMNS`` and ``DESCRIBE`` --------------------------------- Show column information on a specific table. Example: .. raw:: html
    SHOW COLUMNS FROM "timeseries"
    
Result: +--------+---------+---------------+ | Column | Type | Extra Comment | +========+=========+===============+ | id | bigint | | +--------+---------+---------------+ | name | varchar | | +--------+---------+---------------+ | x | double | | +--------+---------+---------------+ | y | double | | +--------+---------+---------------+ The column "Extra Comment" is shown for compatibility with presto. ``ANALYZE TABLE`` ----------------- Calculate statistics on a given table (and the given columns or all columns) and return them as a query result. Please note that this process can be time-consuming on large tables. Even though this statement is very similar to the ``ANALYZE TABLE`` statement in e.g. `Apache Spark `_, it does not optimize subsequent queries (as its counterpart in Spark does). Example: .. raw:: html
    ANALYZE TABLE "timeseries" COMPUTE STATISTICS FOR COLUMNS x, y
    
Result: +-----------+-----------+-----------+ | | x | y | +===========+===========+===========+ | count | 30 | 30 | +-----------+-----------+-----------+ | mean | 0.140374 | -0.107481 | +-----------+-----------+-----------+ | std | 0.568248 | 0.573106 | +-----------+-----------+-----------+ | min | -0.795112 | -0.966043 | +-----------+-----------+-----------+ | 25% | -0.379635 | -0.561234 | +-----------+-----------+-----------+ | 50% | 0.0104101 | -0.237795 | +-----------+-----------+-----------+ | 75% | 0.70208 | 0.263459 | +-----------+-----------+-----------+ | max | 0.990747 | 0.947069 | +-----------+-----------+-----------+ | data_type | double | double | +-----------+-----------+-----------+ | col_name | x | y | +-----------+-----------+-----------+ ================================================ FILE: docs/source/sql/ml.rst ================================================ .. _ml: Machine Learning in SQL ======================= .. note:: Machine Learning support is experimental in ``dask-sql``. We encourage you to try it out and report any issues on our `issue tracker `_. As all SQL statements in ``dask-sql`` are eventually converted to Python calls, it is very simple to include any custom Python function and library, e.g. Machine Learning libraries. Although it would be possible to register custom functions (see :ref:`custom`) for this and use them, it is much more convenient if this functionality is already included in the core SQL language. These three statements help in training and using models. Every :class:`~dask_sql.Context` has a registry for models, which can be used for training or prediction. For a full example, see :ref:`machine_learning`. .. raw:: html
    CREATE [ OR REPLACE ] MODEL [ IF NOT EXISTS ] <model-name>
        WITH ( <key> = <value> [ , ... ] ) AS ( SELECT ... )
    DROP MODEL [ IF EXISTS ] <model-name>
    SELECT <expression> FROM PREDICT (MODEL <model-name>, SELECT ... )
    
``IF [ NOT ] EXISTS`` and ``CREATE OR REPLACE`` behave similarly to their analogous flags in ``CREATE TABLE``. See :ref:`creation` for more information. ``CREATE MODEL`` ---------------- Create and train a model on the data from the given ``SELECT`` query and register it at the context. The select query is a normal ``SELECT`` query (following the same syntax as described in :ref:`select`) or even a call to ``PREDICT`` (which, however, typically does not make sense) and its result is used as the training data. The key-value parameters control how and which model is trained: * ``model_class``: This argument needs to be present. It is the class name or full Python module path to the class of the model to train. Any sklearn, cuML, XGBoost, or LightGBM classes can be inferred without the full path. In this case, models trained on cuDF dataframes are automatically mapped to cuML classes, and sklearn models otherwise. We map to cuML-Dask based models when possible and single-GPU cuML models otherwise. Any model class with an sklearn interface is valid, but might or might not work well with Dask dataframes. You might need to install necessary packages to use the models. * ``target_column``: Which column from the data to use as target. If not empty, it is removed automatically from the training data. Defaults to an empty string, in which case no target is fed to the model training (e.g. for unsupervised algorithms). This means you typically want to set this parameter. * ``wrap_predict``: Boolean flag, whether to wrap the selected model with a :class:`dask_sql.physical.rel.custom.wrappers.ParallelPostFit`. Defaults to true for sklearn and single GPU cuML models and false otherwise. Typically you set it to true for sklearn models if predicting on big data. * ``wrap_fit``: Boolean flag, whether to wrap the selected model with a :class:`dask_sql.physical.rel.custom.wrappers.Incremental`. Defaults to true for sklearn and single GPU cuML models and false otherwise. 
Typically you set it to true for sklearn models if training on big data. * ``fit_kwargs``: keyword arguments sent to the call to ``fit()``. All other arguments are passed to the constructor of the model class. Example: .. raw:: html
CREATE MODEL my_model WITH (
        model_class = 'XGBClassifier',
        target_column = 'target'
    ) AS (
        SELECT x, y, target
        FROM "data"
    )
    
This SQL call is not a 1:1 replacement for a normal python training and cannot fulfill all use-cases or requirements! If you are dealing with large amounts of data, you might run into problems while model training and/or prediction, depending on whether your model can cope with Dask dataframes. * If you are training on relatively small amounts of data but predicting on large data samples, you might want to set ``wrap_predict`` to True. With this option, model inference will be parallelized/distributed. * If you are training on large amounts of data, you can try setting ``wrap_fit`` to True. This will do the same on the training step, but works only for models which have a ``partial_fit`` method. ``DROP MODEL`` -------------- Remove the model with the given name from the registered models. ``SELECT FROM PREDICT`` ----------------------- Predict the target using the given model and dataframe from the ``SELECT`` query. The return value is the input dataframe with an additional column named "target", which contains the predicted values. The model needs to be registered at the context before using it in this function, either by calling :func:`~dask_sql.Context.register_model` explicitly or by training a model using the ``CREATE MODEL`` SQL statement above. A model can be anything which has a ``predict`` function. Please note however, that it will need to act on Dask dataframes. If you are using a model not optimized for this, it might be that you run out of memory if your data is larger than the RAM of a single machine. To prevent this, have a look into the :class:`dask_sql.physical.rel.custom.wrappers.ParallelPostFit` meta-estimator. If you are using a model trained with ``CREATE MODEL`` and the ``wrap_predict`` flag set to true, this is done automatically. Using this SQL statement is roughly equivalent to doing .. code-block:: python df = context.sql("
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
a
0xyz
\n", "" ], "text/plain": [ " a\n", "0 xyz" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "schema: a:str" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%fsql dask\n", "\n", "CREATE [[\"xyz\"], [\"xxx\"]] SCHEMA a:str\n", "SELECT * WHERE a LIKE '%y%'\n", "PRINT" ] }, { "cell_type": "markdown", "id": "7f16b7d9-6b45-4caf-bbcb-63cc5d858556", "metadata": {}, "source": [ "We can also use the `YIELD` keyword to register the results of our queries into Python objects:" ] }, { "cell_type": "code", "execution_count": 4, "id": "521965bc-1a4c-49ab-b48f-789351cb24d4", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
b
0xyz
1xxx-
\n", "
" ], "text/plain": [ " b\n", "0 xyz\n", "1 xxx-" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "schema: b:str" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%fsql dask\n", "src = CREATE [[\"xyz\"], [\"xxx\"]] SCHEMA a:str\n", "\n", "a = SELECT a AS b WHERE a LIKE '%y%'\n", " YIELD DATAFRAME AS test\n", "\n", "b = SELECT CONCAT(a, '-') AS b FROM src WHERE a LIKE '%xx%'\n", " YIELD DATAFRAME AS test1\n", "\n", "SELECT * FROM a UNION SELECT * FROM b\n", "PRINT" ] }, { "cell_type": "markdown", "id": "dfbb0a9a", "metadata": {}, "source": [ "Which can then be interacted with outside of SQL:" ] }, { "cell_type": "code", "execution_count": 5, "id": "79a3e87a-2764-410c-b257-c710c4a6c6d4", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Dask DataFrame Structure:
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
b
npartitions=2
object
...
...
\n", "
\n", "
Dask Name: rename, 16 tasks
" ], "text/plain": [ "Dask DataFrame Structure:\n", " b\n", "npartitions=2 \n", " object\n", " ...\n", " ...\n", "Dask Name: rename, 16 tasks" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test.native # a Dask DataFrame" ] }, { "cell_type": "code", "execution_count": 6, "id": "c98cb652-06e2-444a-b70a-fdd3de9ecd15", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
b
1xxx-
\n", "
" ], "text/plain": [ " b\n", "1 xxx-" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test1.native.compute()" ] }, { "cell_type": "markdown", "id": "932ede31-90b2-49e5-9f4d-7cf1b8d919d2", "metadata": {}, "source": [ "We can also run the equivalent of these queries in python code using `fugue_sql.fsql`, passing the Dask client into its `run` method to specify Dask as an execution engine:" ] }, { "cell_type": "code", "execution_count": 7, "id": "c265b170-de4d-4fab-aeae-9f94031e960d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
a
0xyz
\n", "
" ], "text/plain": [ " a\n", "0 xyz" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "schema: a:str" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "DataFrames()" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from fugue_sql import fsql\n", "\n", "fsql(\"\"\"\n", "CREATE [[\"xyz\"], [\"xxx\"]] SCHEMA a:str\n", "SELECT * WHERE a LIKE '%y%'\n", "PRINT\n", "\"\"\").run(client)" ] }, { "cell_type": "code", "execution_count": 8, "id": "77e3bf50-8c8b-4e2f-a5e7-28b1d86499d7", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Dask DataFrame Structure:
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
a
npartitions=2
object
...
...
\n", "
\n", "
Dask Name: rename, 16 tasks
" ], "text/plain": [ "Dask DataFrame Structure:\n", " a\n", "npartitions=2 \n", " object\n", " ...\n", " ...\n", "Dask Name: rename, 16 tasks" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result = fsql(\"\"\"\n", "CREATE [[\"xyz\"], [\"xxx\"]] SCHEMA a:str\n", "SELECT * WHERE a LIKE '%y%'\n", "YIELD DATAFRAME AS test2\n", "\"\"\").run(client)\n", "\n", "result[\"test2\"].native # a Dask DataFrame" ] }, { "cell_type": "code", "execution_count": null, "id": "7d4c71d4-238f-4c72-8609-dbbe0782aea9", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "vscode": { "interpreter": { "hash": "656801d214ad98d4b301386b078628ce3ae2dbd81a59ed4deed7a5b13edfab09" } } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: notebooks/iris.csv ================================================ sepal_length,sepal_width,petal_length,petal_width,species 5.1,3.5,1.4,0.2,setosa 4.9,3.0,1.4,0.2,setosa 4.7,3.2,1.3,0.2,setosa 4.6,3.1,1.5,0.2,setosa 5.0,3.6,1.4,0.2,setosa 5.4,3.9,1.7,0.4,setosa 4.6,3.4,1.4,0.3,setosa 5.0,3.4,1.5,0.2,setosa 4.4,2.9,1.4,0.2,setosa 4.9,3.1,1.5,0.1,setosa 5.4,3.7,1.5,0.2,setosa 4.8,3.4,1.6,0.2,setosa 4.8,3.0,1.4,0.1,setosa 4.3,3.0,1.1,0.1,setosa 5.8,4.0,1.2,0.2,setosa 5.7,4.4,1.5,0.4,setosa 5.4,3.9,1.3,0.4,setosa 5.1,3.5,1.4,0.3,setosa 5.7,3.8,1.7,0.3,setosa 5.1,3.8,1.5,0.3,setosa 5.4,3.4,1.7,0.2,setosa 5.1,3.7,1.5,0.4,setosa 4.6,3.6,1.0,0.2,setosa 5.1,3.3,1.7,0.5,setosa 4.8,3.4,1.9,0.2,setosa 5.0,3.0,1.6,0.2,setosa 5.0,3.4,1.6,0.4,setosa 5.2,3.5,1.5,0.2,setosa 5.2,3.4,1.4,0.2,setosa 4.7,3.2,1.6,0.2,setosa 4.8,3.1,1.6,0.2,setosa 
5.4,3.4,1.5,0.4,setosa 5.2,4.1,1.5,0.1,setosa 5.5,4.2,1.4,0.2,setosa 4.9,3.1,1.5,0.1,setosa 5.0,3.2,1.2,0.2,setosa 5.5,3.5,1.3,0.2,setosa 4.9,3.1,1.5,0.1,setosa 4.4,3.0,1.3,0.2,setosa 5.1,3.4,1.5,0.2,setosa 5.0,3.5,1.3,0.3,setosa 4.5,2.3,1.3,0.3,setosa 4.4,3.2,1.3,0.2,setosa 5.0,3.5,1.6,0.6,setosa 5.1,3.8,1.9,0.4,setosa 4.8,3.0,1.4,0.3,setosa 5.1,3.8,1.6,0.2,setosa 4.6,3.2,1.4,0.2,setosa 5.3,3.7,1.5,0.2,setosa 5.0,3.3,1.4,0.2,setosa 7.0,3.2,4.7,1.4,versicolor 6.4,3.2,4.5,1.5,versicolor 6.9,3.1,4.9,1.5,versicolor 5.5,2.3,4.0,1.3,versicolor 6.5,2.8,4.6,1.5,versicolor 5.7,2.8,4.5,1.3,versicolor 6.3,3.3,4.7,1.6,versicolor 4.9,2.4,3.3,1.0,versicolor 6.6,2.9,4.6,1.3,versicolor 5.2,2.7,3.9,1.4,versicolor 5.0,2.0,3.5,1.0,versicolor 5.9,3.0,4.2,1.5,versicolor 6.0,2.2,4.0,1.0,versicolor 6.1,2.9,4.7,1.4,versicolor 5.6,2.9,3.6,1.3,versicolor 6.7,3.1,4.4,1.4,versicolor 5.6,3.0,4.5,1.5,versicolor 5.8,2.7,4.1,1.0,versicolor 6.2,2.2,4.5,1.5,versicolor 5.6,2.5,3.9,1.1,versicolor 5.9,3.2,4.8,1.8,versicolor 6.1,2.8,4.0,1.3,versicolor 6.3,2.5,4.9,1.5,versicolor 6.1,2.8,4.7,1.2,versicolor 6.4,2.9,4.3,1.3,versicolor 6.6,3.0,4.4,1.4,versicolor 6.8,2.8,4.8,1.4,versicolor 6.7,3.0,5.0,1.7,versicolor 6.0,2.9,4.5,1.5,versicolor 5.7,2.6,3.5,1.0,versicolor 5.5,2.4,3.8,1.1,versicolor 5.5,2.4,3.7,1.0,versicolor 5.8,2.7,3.9,1.2,versicolor 6.0,2.7,5.1,1.6,versicolor 5.4,3.0,4.5,1.5,versicolor 6.0,3.4,4.5,1.6,versicolor 6.7,3.1,4.7,1.5,versicolor 6.3,2.3,4.4,1.3,versicolor 5.6,3.0,4.1,1.3,versicolor 5.5,2.5,4.0,1.3,versicolor 5.5,2.6,4.4,1.2,versicolor 6.1,3.0,4.6,1.4,versicolor 5.8,2.6,4.0,1.2,versicolor 5.0,2.3,3.3,1.0,versicolor 5.6,2.7,4.2,1.3,versicolor 5.7,3.0,4.2,1.2,versicolor 5.7,2.9,4.2,1.3,versicolor 6.2,2.9,4.3,1.3,versicolor 5.1,2.5,3.0,1.1,versicolor 5.7,2.8,4.1,1.3,versicolor 6.3,3.3,6.0,2.5,virginica 5.8,2.7,5.1,1.9,virginica 7.1,3.0,5.9,2.1,virginica 6.3,2.9,5.6,1.8,virginica 6.5,3.0,5.8,2.2,virginica 7.6,3.0,6.6,2.1,virginica 4.9,2.5,4.5,1.7,virginica 7.3,2.9,6.3,1.8,virginica 
6.7,2.5,5.8,1.8,virginica 7.2,3.6,6.1,2.5,virginica 6.5,3.2,5.1,2.0,virginica 6.4,2.7,5.3,1.9,virginica 6.8,3.0,5.5,2.1,virginica 5.7,2.5,5.0,2.0,virginica 5.8,2.8,5.1,2.4,virginica 6.4,3.2,5.3,2.3,virginica 6.5,3.0,5.5,1.8,virginica 7.7,3.8,6.7,2.2,virginica 7.7,2.6,6.9,2.3,virginica 6.0,2.2,5.0,1.5,virginica 6.9,3.2,5.7,2.3,virginica 5.6,2.8,4.9,2.0,virginica 7.7,2.8,6.7,2.0,virginica 6.3,2.7,4.9,1.8,virginica 6.7,3.3,5.7,2.1,virginica 7.2,3.2,6.0,1.8,virginica 6.2,2.8,4.8,1.8,virginica 6.1,3.0,4.9,1.8,virginica 6.4,2.8,5.6,2.1,virginica 7.2,3.0,5.8,1.6,virginica 7.4,2.8,6.1,1.9,virginica 7.9,3.8,6.4,2.0,virginica 6.4,2.8,5.6,2.2,virginica 6.3,2.8,5.1,1.5,virginica 6.1,2.6,5.6,1.4,virginica 7.7,3.0,6.1,2.3,virginica 6.3,3.4,5.6,2.4,virginica 6.4,3.1,5.5,1.8,virginica 6.0,3.0,4.8,1.8,virginica 6.9,3.1,5.4,2.1,virginica 6.7,3.1,5.6,2.4,virginica 6.9,3.1,5.1,2.3,virginica 5.8,2.7,5.1,1.9,virginica 6.8,3.2,5.9,2.3,virginica 6.7,3.3,5.7,2.5,virginica 6.7,3.0,5.2,2.3,virginica 6.3,2.5,5.0,1.9,virginica 6.5,3.0,5.2,2.0,virginica 6.2,3.4,5.4,2.3,virginica 5.9,3.0,5.1,1.8,virginica ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["maturin>=1.3,<1.4"] build-backend = "maturin" [project] name = "dask_sql" description = "SQL query layer for Dask" maintainers = [{name = "Nils Braun", email = "nilslennartbraun@gmail.com"}] license = {text = "MIT"} classifiers = [ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Rust", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Topic 
:: Scientific/Engineering", "Topic :: System :: Distributed Computing", ] readme = "README.md" requires-python = ">=3.9" dependencies = [ "dask[dataframe]>=2024.4.1", "distributed>=2024.4.1", "pandas>=1.4.0", "fastapi>=0.92.0", "httpx>=0.24.1", "uvicorn>=0.14", "tzlocal>=2.1", "prompt_toolkit>=3.0.8", "pygments>=2.7.1", "tabulate", ] dynamic = ["version"] [project.urls] Homepage = "https://github.com/dask-contrib/dask-sql" Documentation = "https://dask-sql.readthedocs.io" Source = "https://github.com/dask-contrib/dask-sql" [project.optional-dependencies] dev = [ "pytest>=6.0.1", "pytest-cov>=2.10.1", "mock>=4.0.3", "sphinx>=3.2.1", "pyarrow>=14.0.1", "scikit-learn>=1.0.0", "intake>=0.6.0", "pre-commit", "black==22.10.0", "isort==5.12.0", ] fugue = [ "fugue>=0.7.3", # FIXME: https://github.com/fugue-project/fugue/issues/526 "triad<0.9.2", ] [project.entry-points."fugue.plugins"] dasksql = "dask_sql.integrations.fugue:_register_engines[fugue]" [project.scripts] dask-sql = "dask_sql.cmd:main" dask-sql-server = "dask_sql.server.app:main" [tool.setuptools] include-package-data = true zip-safe = false license-files = ["LICENSE.txt"] [tool.setuptools.packages] find = {namespaces = false} [tool.maturin] module-name = "dask_sql._datafusion_lib" include = [ { path = "Cargo.lock", format = "sdist" } ] exclude = [".github/**", "continuous_integration/**"] locked = true [tool.isort] profile = "black" [tool.pytest.ini_options] markers = [ "gpu: marks tests that require GPUs (skipped by default, run with --rungpu)", "queries: marks tests that run test queries (skipped by default, run with --runqueries)", ] addopts = "-v -rsxfE --color=yes --cov dask_sql --cov-config=.coveragerc --cov-report=term-missing" filterwarnings = [ "error:::dask_sql[.*]", "error:::dask[.*]", "ignore:Need to do a cross-join:ResourceWarning:dask_sql[.*]", "ignore:Dask doesn't support Dask frames:ResourceWarning:dask_sql[.*]", "ignore:Running on a single-machine scheduler:UserWarning:dask[.*]", 
"ignore:Merging dataframes with merge column data type mismatches:UserWarning:dask[.*]", ] xfail_strict = true ================================================ FILE: rustfmt.toml ================================================ imports_layout = "HorizontalVertical" imports_granularity = "Crate" group_imports = "StdExternalCrate" ================================================ FILE: setup.cfg ================================================ [flake8] # References: # https://flake8.readthedocs.io/en/latest/user/configuration.html # https://flake8.readthedocs.io/en/latest/user/error-codes.html # https://pycodestyle.pycqa.org/en/latest/intro.html#error-codes exclude = __init__.py ignore = E203, # whitespace before ':' E231,E241, # Multiple spaces around "," E731, # Assigning lambda expression #E741, # Ambiguous variable names W503, # line break before binary operator W504, # line break after binary operator ; F821, # undefined name per-file-ignores = tests/*: # local variable is assigned to but never used F841, # Ambiguous variable name E741, max-line-length = 150 ================================================ FILE: src/dialect.rs ================================================ use core::{iter::Peekable, str::Chars}; use datafusion_python::datafusion_sql::sqlparser::{ ast::{Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Value}, dialect::Dialect, keywords::Keyword, parser::{Parser, ParserError}, tokenizer::Token, }; #[derive(Debug)] pub struct DaskDialect {} impl Dialect for DaskDialect { fn is_identifier_start(&self, ch: char) -> bool { // See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS // We don't yet support identifiers beginning with "letters with // diacritical marks and non-Latin letters" ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_' } fn is_identifier_part(&self, ch: char) -> bool { ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '$' || ch == '_' } /// 
Determine if a character starts a quoted identifier. The default /// implementation, accepting "double quoted" ids is both ANSI-compliant /// and appropriate for most dialects (with the notable exception of /// MySQL, MS SQL, and sqlite). You can accept one of characters listed /// in `Word::matching_end_quote` here fn is_delimited_identifier_start(&self, ch: char) -> bool { ch == '"' } /// Determine if quoted characters are proper for identifier fn is_proper_identifier_inside_quotes(&self, mut _chars: Peekable>) -> bool { true } /// Determine if FILTER (WHERE ...) filters are allowed during aggregations fn supports_filter_during_aggregation(&self) -> bool { true } /// override expression parsing fn parse_prefix(&self, parser: &mut Parser) -> Option> { fn parse_expr(parser: &mut Parser) -> Result, ParserError> { match parser.peek_token().token { Token::Word(w) if w.value.to_lowercase() == "ceil" => { // CEIL(d TO DAY) parser.next_token(); // skip ceil parser.expect_token(&Token::LParen)?; let expr = parser.parse_expr()?; if !parser.parse_keyword(Keyword::TO) { // Parse CEIL(expr) as normal parser.prev_token(); parser.prev_token(); parser.prev_token(); return Ok(None); } let time_unit = parser.next_token(); parser.expect_token(&Token::RParen)?; // convert to function args let args = vec![ FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)), FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( Value::SingleQuotedString(time_unit.to_string()), ))), ]; Ok(Some(Expr::Function(Function { name: ObjectName(vec![Ident::new("timestampceil")]), args, over: None, distinct: false, special: false, order_by: vec![], }))) } Token::Word(w) if w.value.to_lowercase() == "floor" => { // FLOOR(d TO DAY) parser.next_token(); // skip floor parser.expect_token(&Token::LParen)?; let expr = parser.parse_expr()?; if !parser.parse_keyword(Keyword::TO) { // Parse FLOOR(expr) as normal parser.prev_token(); parser.prev_token(); parser.prev_token(); return Ok(None); } let time_unit = 
parser.next_token(); parser.expect_token(&Token::RParen)?; // convert to function args let args = vec![ FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)), FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( Value::SingleQuotedString(time_unit.to_string()), ))), ]; Ok(Some(Expr::Function(Function { name: ObjectName(vec![Ident::new("timestampfloor")]), args, over: None, distinct: false, special: false, order_by: vec![], }))) } Token::Word(w) if w.value.to_lowercase() == "timestampadd" => { // TIMESTAMPADD(YEAR, 2, d) parser.next_token(); // skip timestampadd parser.expect_token(&Token::LParen)?; let time_unit = parser.next_token(); parser.expect_token(&Token::Comma)?; let n = parser.parse_expr()?; parser.expect_token(&Token::Comma)?; let expr = parser.parse_expr()?; parser.expect_token(&Token::RParen)?; // convert to function args let args = vec![ FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( Value::SingleQuotedString(time_unit.to_string()), ))), FunctionArg::Unnamed(FunctionArgExpr::Expr(n)), FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)), ]; Ok(Some(Expr::Function(Function { name: ObjectName(vec![Ident::new("timestampadd")]), args, over: None, distinct: false, special: false, order_by: vec![], }))) } Token::Word(w) if w.value.to_lowercase() == "timestampdiff" => { parser.next_token(); // skip timestampdiff parser.expect_token(&Token::LParen)?; let time_unit = parser.next_token(); parser.expect_token(&Token::Comma)?; let expr1 = parser.parse_expr()?; parser.expect_token(&Token::Comma)?; let expr2 = parser.parse_expr()?; parser.expect_token(&Token::RParen)?; // convert to function args let args = vec![ FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( Value::SingleQuotedString(time_unit.to_string()), ))), FunctionArg::Unnamed(FunctionArgExpr::Expr(expr1)), FunctionArg::Unnamed(FunctionArgExpr::Expr(expr2)), ]; Ok(Some(Expr::Function(Function { name: ObjectName(vec![Ident::new("timestampdiff")]), args, over: None, distinct: false, special: 
false, order_by: vec![], }))) } Token::Word(w) if w.value.to_lowercase() == "to_timestamp" => { // TO_TIMESTAMP(d, "%d/%m/%Y") parser.next_token(); // skip to_timestamp parser.expect_token(&Token::LParen)?; let expr = parser.parse_expr()?; let comma = parser.consume_token(&Token::Comma); let time_format = if comma { parser.next_token().to_string() } else { "%Y-%m-%d %H:%M:%S".to_string() }; parser.expect_token(&Token::RParen)?; // convert to function args let args = vec![ FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)), FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( Value::SingleQuotedString(time_format), ))), ]; Ok(Some(Expr::Function(Function { name: ObjectName(vec![Ident::new("dsql_totimestamp")]), args, over: None, distinct: false, special: false, order_by: vec![], }))) } Token::Word(w) if w.value.to_lowercase() == "extract" => { // EXTRACT(DATE FROM d) parser.next_token(); // skip extract parser.expect_token(&Token::LParen)?; if !parser.parse_keywords(&[Keyword::DATE, Keyword::FROM]) { // Parse EXTRACT(x FROM d) as normal parser.prev_token(); parser.prev_token(); return Ok(None); } let expr = parser.parse_expr()?; parser.expect_token(&Token::RParen)?; // convert to function args let args = vec![ FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( Value::SingleQuotedString("DATE".to_string()), ))), FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)), ]; Ok(Some(Expr::Function(Function { name: ObjectName(vec![Ident::new("extract_date")]), args, over: None, distinct: false, special: false, order_by: vec![], }))) } _ => Ok(None), } } match parse_expr(parser) { Ok(Some(expr)) => Some(Ok(expr)), Ok(None) => None, Err(e) => Some(Err(e)), } } } ================================================ FILE: src/error.rs ================================================ use std::fmt::{Display, Formatter}; use datafusion_python::{ datafusion_common::DataFusionError, datafusion_sql::sqlparser::{parser::ParserError, tokenizer::TokenizerError}, }; use pyo3::PyErr; pub 
type Result = std::result::Result; #[derive(Debug)] pub enum DaskPlannerError { DataFusionError(DataFusionError), ParserError(ParserError), TokenizerError(TokenizerError), Internal(String), InvalidIOFilter(String), } impl Display for DaskPlannerError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { Self::DataFusionError(e) => write!(f, "DataFusion Error: {e}"), Self::ParserError(e) => write!(f, "SQL Parser Error: {e}"), Self::TokenizerError(e) => write!(f, "SQL Tokenizer Error: {e}"), Self::Internal(e) => write!(f, "Internal Error: {e}"), Self::InvalidIOFilter(e) => write!(f, "Invalid pyarrow filter: {e} encountered. Defaulting to Dask CPU/GPU bound task operation"), } } } impl From for DaskPlannerError { fn from(err: TokenizerError) -> Self { Self::TokenizerError(err) } } impl From for DaskPlannerError { fn from(err: ParserError) -> Self { Self::ParserError(err) } } impl From for DaskPlannerError { fn from(err: DataFusionError) -> Self { Self::DataFusionError(err) } } impl From for PyErr { fn from(err: DaskPlannerError) -> PyErr { PyErr::new::(format!("{err:?}")) } } ================================================ FILE: src/expression.rs ================================================ use std::{borrow::Cow, convert::From, sync::Arc}; use datafusion_python::{ datafusion::arrow::datatypes::DataType, datafusion_common::{Column, DFField, DFSchema, ScalarValue}, datafusion_expr::{ expr::{ AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, Exists, InList, InSubquery, ScalarFunction, ScalarUDF, Sort, TryCast, WindowFunction, }, lit, utils::exprlist_to_fields, Between, BuiltinScalarFunction, Case, Expr, GetIndexedField, Like, LogicalPlan, Operator, }, datafusion_sql::TableReference, }; use pyo3::prelude::*; use crate::{ error::{DaskPlannerError, Result}, sql::{ exceptions::{py_runtime_err, py_type_err}, logical, types::RexType, }, }; /// An PyExpr that can be used on a DataFrame #[pyclass(name = "Expression", module = "dask_sql", 
subclass)] #[derive(Debug, Clone)] pub struct PyExpr { pub expr: Expr, // Why a Vec here? Because BinaryExpr on Join might have multiple LogicalPlans pub input_plan: Option>>, } impl From for Expr { fn from(expr: PyExpr) -> Expr { expr.expr } } #[pyclass(name = "ScalarValue", module = "dask_sql", subclass)] #[derive(Debug, Clone)] pub struct PyScalarValue { pub scalar_value: ScalarValue, } impl From for ScalarValue { fn from(pyscalar: PyScalarValue) -> ScalarValue { pyscalar.scalar_value } } impl From for PyScalarValue { fn from(scalar_value: ScalarValue) -> PyScalarValue { PyScalarValue { scalar_value } } } /// Convert a list of DataFusion Expr to PyExpr pub fn py_expr_list(input: &Arc, expr: &[Expr]) -> PyResult> { Ok(expr .iter() .map(|e| PyExpr::from(e.clone(), Some(vec![input.clone()]))) .collect()) } impl PyExpr { /// Generally we would implement the `From` trait offered by Rust /// However in this case Expr does not contain the contextual /// `LogicalPlan` instance that we need so we need to make a instance /// function to take and create the PyExpr. pub fn from(expr: Expr, input: Option>>) -> PyExpr { PyExpr { input_plan: input, expr, } } /// Determines the name of the `Expr` instance by examining the LogicalPlan pub fn _column_name(&self, plan: &LogicalPlan) -> Result { let field = expr_to_field(&self.expr, plan)?; Ok(field.qualified_column().flat_name()) } fn _rex_type(&self, expr: &Expr) -> RexType { match expr { Expr::Alias(..) => RexType::Alias, Expr::Column(..) | Expr::QualifiedWildcard { .. } | Expr::GetIndexedField { .. } | Expr::Wildcard => RexType::Reference, Expr::ScalarVariable(..) | Expr::Literal(..) => RexType::Literal, Expr::BinaryExpr { .. } | Expr::Not(..) | Expr::IsNotNull(..) | Expr::Negative(..) | Expr::IsNull(..) | Expr::Like { .. } | Expr::SimilarTo { .. } | Expr::Between { .. } | Expr::Case { .. } | Expr::Cast { .. } | Expr::TryCast { .. } | Expr::Sort { .. } | Expr::ScalarFunction { .. } | Expr::AggregateFunction { .. 
} | Expr::WindowFunction { .. } | Expr::AggregateUDF { .. } | Expr::InList { .. } | Expr::ScalarUDF { .. } | Expr::Exists { .. } | Expr::InSubquery { .. } | Expr::GroupingSet(..) | Expr::IsTrue(..) | Expr::IsFalse(..) | Expr::IsUnknown(_) | Expr::IsNotTrue(..) | Expr::IsNotFalse(..) | Expr::Placeholder { .. } | Expr::OuterReferenceColumn(_, _) | Expr::IsNotUnknown(_) => RexType::Call, Expr::ScalarSubquery(..) => RexType::ScalarSubquery, } } } macro_rules! extract_scalar_value { ($self: expr, $variant: ident) => { match $self.get_scalar_value()? { ScalarValue::$variant(value) => Ok(*value), other => Err(unexpected_literal_value(other)), } }; } #[pymethods] impl PyExpr { #[staticmethod] pub fn literal(value: PyScalarValue) -> PyExpr { PyExpr::from(lit(value.scalar_value), None) } /// Extracts the LogicalPlan from a Subquery, or supported Subquery sub-type, from /// the expression instance #[pyo3(name = "getSubqueryLogicalPlan")] pub fn subquery_plan(&self) -> PyResult { match &self.expr { Expr::ScalarSubquery(subquery) => Ok(subquery.subquery.as_ref().clone().into()), Expr::InSubquery(insubquery) => { Ok(insubquery.subquery.subquery.as_ref().clone().into()) } _ => Err(py_type_err(format!( "Attempted to extract a LogicalPlan instance from invalid Expr {:?}. 
Only Subquery and related variants are supported for this operation.", &self.expr ))), } } /// If this Expression instances references an existing /// Column in the SQL parse tree or not #[pyo3(name = "isInputReference")] pub fn is_input_reference(&self) -> PyResult { Ok(matches!(&self.expr, Expr::Column(_col))) } #[pyo3(name = "toString")] pub fn to_string(&self) -> PyResult { Ok(format!("{}", &self.expr)) } /// Gets the positional index of the Expr instance from the LogicalPlan DFSchema #[pyo3(name = "getIndex")] pub fn index(&self) -> PyResult { let input: &Option>> = &self.input_plan; match input { Some(input_plans) if !input_plans.is_empty() => { let mut schema: DFSchema = (**input_plans[0].schema()).clone(); for plan in input_plans.iter().skip(1) { schema.merge(plan.schema().as_ref()); } let name = get_expr_name(&self.expr).map_err(py_runtime_err)?; if name != "*" { schema .index_of_column(&Column::from_qualified_name(name.clone())) .or_else(|_| { // Handles cases when from_qualified_name doesn't format the Column correctly. // "name" will always contain the name of the column. Anything in addition to // that will be separated by a '.' and should be further referenced. 
match &self.expr { Expr::Column(col) => { schema.index_of_column(col).map_err(py_runtime_err) } _ => { let parts = name.split('.').collect::>(); let tbl_reference = match parts.len() { // Single element means name contains just the column name so no TableReference 1 => None, // Tablename.column_name 2 => Some( TableReference::Bare { table: Cow::Borrowed(parts[0]), } .to_owned_reference(), ), // Schema_name.table_name.column_name 3 => Some( TableReference::Partial { schema: Cow::Borrowed(parts[0]), table: Cow::Borrowed(parts[1]), } .to_owned_reference(), ), // catalog_name.schema_name.table_name.column_name 4 => Some( TableReference::Full { catalog: Cow::Borrowed(parts[0]), schema: Cow::Borrowed(parts[1]), table: Cow::Borrowed(parts[2]), } .to_owned_reference(), ), _ => None, }; let col = Column { relation: tbl_reference.clone(), name: parts[parts.len() - 1].to_string(), }; schema.index_of_column(&col).map_err(py_runtime_err) } } }) } else { // Since this is wildcard any Column will do, just use first one Ok(0) } } _ => Err(py_runtime_err( "We need a valid LogicalPlan instance to get the Expr's index in the schema", )), } } /// Examine the current/"self" PyExpr and return its "type" /// In this context a "type" is what Dask-SQL Python /// RexConverter plugin instance should be invoked to handle /// the Rex conversion #[pyo3(name = "getExprType")] pub fn get_expr_type(&self) -> PyResult { Ok(String::from(match &self.expr { Expr::Alias(..) | Expr::Column(..) | Expr::Literal(..) | Expr::BinaryExpr { .. } | Expr::Between { .. } | Expr::Cast { .. } | Expr::Sort { .. } | Expr::ScalarFunction { .. } | Expr::AggregateFunction { .. } | Expr::InList { .. } | Expr::InSubquery { .. } | Expr::ScalarUDF { .. } | Expr::AggregateUDF { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(..) | Expr::QualifiedWildcard { .. } | Expr::Not(..) | Expr::OuterReferenceColumn(_, _) | Expr::GroupingSet(..) => self.expr.variant_name(), Expr::ScalarVariable(..) | Expr::IsNotNull(..) 
| Expr::Negative(..) | Expr::GetIndexedField { .. } | Expr::IsNull(..) | Expr::IsTrue(_) | Expr::IsFalse(_) | Expr::IsUnknown(_) | Expr::IsNotTrue(_) | Expr::IsNotFalse(_) | Expr::Like { .. } | Expr::SimilarTo { .. } | Expr::IsNotUnknown(_) | Expr::Case { .. } | Expr::TryCast { .. } | Expr::WindowFunction { .. } | Expr::Placeholder { .. } | Expr::Wildcard => { return Err(py_type_err(format!( "Encountered unsupported expression type: {}", &self.expr.variant_name() ))) } })) } /// Determines the type of this Expr based on its variant #[pyo3(name = "getRexType")] pub fn rex_type(&self) -> PyResult { Ok(self._rex_type(&self.expr)) } /// Python friendly shim code to get the name of a column referenced by an expression pub fn column_name(&self, mut plan: logical::PyLogicalPlan) -> PyResult { self._column_name(&plan.current_node()) .map_err(py_runtime_err) } /// Row expressions, Rex(s), operate on the concept of operands. This maps to expressions that are used in /// the "call" logic of the Dask-SQL python codebase. Different variants of Expressions, Expr(s), /// store those operands in different datastructures. This function examines the Expr variant and returns /// the operands to the calling logic as a Vec of PyExpr instances. #[pyo3(name = "getOperands")] pub fn get_operands(&self) -> PyResult> { match &self.expr { // Expr variants that are themselves the operand to return Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Literal(..) => { Ok(vec![PyExpr::from( self.expr.clone(), self.input_plan.clone(), )]) } // Expr(s) that house the Expr instance to return in their bounded params Expr::Not(expr) | Expr::IsNull(expr) | Expr::IsNotNull(expr) | Expr::IsTrue(expr) | Expr::IsFalse(expr) | Expr::IsUnknown(expr) | Expr::IsNotTrue(expr) | Expr::IsNotFalse(expr) | Expr::IsNotUnknown(expr) | Expr::Negative(expr) | Expr::GetIndexedField(GetIndexedField { expr, .. }) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. 
}) | Expr::InSubquery(InSubquery { expr, .. }) => { Ok(vec![PyExpr::from(*expr.clone(), self.input_plan.clone())]) } // Expr variants containing a collection of Expr(s) for operands Expr::AggregateFunction(AggregateFunction { args, .. }) | Expr::AggregateUDF(AggregateUDF { args, .. }) | Expr::ScalarFunction(ScalarFunction { args, .. }) | Expr::ScalarUDF(ScalarUDF { args, .. }) | Expr::WindowFunction(WindowFunction { args, .. }) => Ok(args .iter() .map(|arg| PyExpr::from(arg.clone(), self.input_plan.clone())) .collect()), // Expr(s) that require more specific processing Expr::Case(Case { expr, when_then_expr, else_expr, }) => { let mut operands: Vec = Vec::new(); if let Some(e) = expr { for (when, then) in when_then_expr { operands.push(PyExpr::from( Expr::BinaryExpr(BinaryExpr::new( Box::new(*e.clone()), Operator::Eq, Box::new(*when.clone()), )), self.input_plan.clone(), )); operands.push(PyExpr::from(*then.clone(), self.input_plan.clone())); } } else { for (when, then) in when_then_expr { operands.push(PyExpr::from(*when.clone(), self.input_plan.clone())); operands.push(PyExpr::from(*then.clone(), self.input_plan.clone())); } }; if let Some(e) = else_expr { operands.push(PyExpr::from(*e.clone(), self.input_plan.clone())); }; Ok(operands) } Expr::Alias(Alias { expr, .. }) => { Ok(vec![PyExpr::from(*expr.clone(), self.input_plan.clone())]) } Expr::InList(InList { expr, list, .. }) => { let mut operands: Vec = vec![PyExpr::from(*expr.clone(), self.input_plan.clone())]; for list_elem in list { operands.push(PyExpr::from(list_elem.clone(), self.input_plan.clone())); } Ok(operands) } Expr::BinaryExpr(BinaryExpr { left, right, .. }) => Ok(vec![ PyExpr::from(*left.clone(), self.input_plan.clone()), PyExpr::from(*right.clone(), self.input_plan.clone()), ]), Expr::Like(Like { expr, pattern, .. }) => Ok(vec![ PyExpr::from(*expr.clone(), self.input_plan.clone()), PyExpr::from(*pattern.clone(), self.input_plan.clone()), ]), Expr::SimilarTo(Like { expr, pattern, .. 
}) => Ok(vec![ PyExpr::from(*expr.clone(), self.input_plan.clone()), PyExpr::from(*pattern.clone(), self.input_plan.clone()), ]), Expr::Between(Between { expr, negated: _, low, high, }) => Ok(vec![ PyExpr::from(*expr.clone(), self.input_plan.clone()), PyExpr::from(*low.clone(), self.input_plan.clone()), PyExpr::from(*high.clone(), self.input_plan.clone()), ]), Expr::Wildcard => Ok(vec![PyExpr::from( self.expr.clone(), self.input_plan.clone(), )]), // Currently un-support/implemented Expr types for Rex Call operations Expr::GroupingSet(..) | Expr::OuterReferenceColumn(_, _) | Expr::QualifiedWildcard { .. } | Expr::ScalarSubquery(..) | Expr::Placeholder { .. } | Expr::Exists { .. } => Err(py_runtime_err(format!( "Unimplemented Expr type: {}", self.expr ))), } } #[pyo3(name = "getOperatorName")] pub fn get_operator_name(&self) -> PyResult { Ok(match &self.expr { Expr::BinaryExpr(BinaryExpr { left: _, op, right: _, }) => format!("{op}"), Expr::ScalarFunction(ScalarFunction { fun, args: _ }) => format!("{fun}"), Expr::ScalarUDF(ScalarUDF { fun, .. }) => fun.name.clone(), Expr::Cast { .. } => "cast".to_string(), Expr::Between { .. } => "between".to_string(), Expr::Case { .. } => "case".to_string(), Expr::IsNull(..) => "is null".to_string(), Expr::IsNotNull(..) => "is not null".to_string(), Expr::IsTrue(_) => "is true".to_string(), Expr::IsFalse(_) => "is false".to_string(), Expr::IsUnknown(_) => "is unknown".to_string(), Expr::IsNotTrue(_) => "is not true".to_string(), Expr::IsNotFalse(_) => "is not false".to_string(), Expr::IsNotUnknown(_) => "is not unknown".to_string(), Expr::InList { .. } => "in list".to_string(), Expr::InSubquery(..) => "in subquery".to_string(), Expr::Negative(..) => "negative".to_string(), Expr::Not(..) => "not".to_string(), Expr::Like(Like { negated, case_insensitive, .. }) => { format!( "{}{}like", if *negated { "not " } else { "" }, if *case_insensitive { "i" } else { "" } ) } Expr::SimilarTo(Like { negated, .. 
}) => {
                if *negated {
                    "not similar to".to_string()
                } else {
                    "similar to".to_string()
                }
            }
            _ => {
                return Err(py_type_err(format!(
                    "Catch all triggered in get_operator_name: {:?}",
                    &self.expr
                )))
            }
        })
    }

    /// Returns a type name (as a string) describing this expression.
    ///
    /// * Binary expressions are classified by operator alone: comparison, logical,
    ///   distinct-from and regex operators yield "BOOLEAN"; plus/minus/multiply/modulo
    ///   yield "BIGINT"; divide yields "FLOAT"; string concatenation yields "VARCHAR".
    /// * Literals yield the name of their `ScalarValue` variant.
    /// * Scalar functions and casts are handled by the remaining match arms.
    ///
    /// Returns a type error for operators or expression variants whose type
    /// cannot be determined here.
    #[pyo3(name = "getType")]
    pub fn get_type(&self) -> PyResult {
        Ok(String::from(match &self.expr {
            Expr::BinaryExpr(BinaryExpr {
                left: _,
                op,
                right: _,
            }) => match op {
                Operator::Eq
                | Operator::NotEq
                | Operator::Lt
                | Operator::LtEq
                | Operator::Gt
                | Operator::GtEq
                | Operator::And
                | Operator::Or
                | Operator::IsDistinctFrom
                | Operator::IsNotDistinctFrom
                | Operator::RegexMatch
                | Operator::RegexIMatch
                | Operator::RegexNotMatch
                | Operator::RegexNotIMatch => "BOOLEAN",
                Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Modulo => {
                    "BIGINT"
                }
                Operator::Divide => "FLOAT",
                Operator::StringConcat => "VARCHAR",
                Operator::BitwiseShiftLeft
                | Operator::BitwiseShiftRight
                | Operator::BitwiseXor
                | Operator::BitwiseAnd
                | Operator::BitwiseOr => {
                    // the type here should be the same as the type of the left expression
                    // but we can only compute that if we have the schema available
                    return Err(py_type_err(
                        "Bitwise operators unsupported in get_type".to_string(),
                    ));
                }
                Operator::AtArrow | Operator::ArrowAt => {
                    // Report an error to the Python caller instead of panicking
                    // (the previous `todo!()` aborted the call with a Rust panic).
                    return Err(py_type_err(
                        "AtArrow/ArrowAt operators unsupported in get_type".to_string(),
                    ));
                }
            },
            Expr::Literal(scalar_value) => match scalar_value {
                ScalarValue::Boolean(_value) => "Boolean",
                ScalarValue::Float32(_value) => "Float32",
                ScalarValue::Float64(_value) => "Float64",
                ScalarValue::Decimal128(_value, ..) => "Decimal128",
                ScalarValue::Decimal256(_, _, _) => "Decimal256",
                ScalarValue::Dictionary(..)
=> "Dictionary", ScalarValue::Int8(_value) => "Int8", ScalarValue::Int16(_value) => "Int16", ScalarValue::Int32(_value) => "Int32", ScalarValue::Int64(_value) => "Int64", ScalarValue::UInt8(_value) => "UInt8", ScalarValue::UInt16(_value) => "UInt16", ScalarValue::UInt32(_value) => "UInt32", ScalarValue::UInt64(_value) => "UInt64", ScalarValue::Utf8(_value) => "Utf8", ScalarValue::LargeUtf8(_value) => "LargeUtf8", ScalarValue::Binary(_value) => "Binary", ScalarValue::LargeBinary(_value) => "LargeBinary", ScalarValue::Date32(_value) => "Date32", ScalarValue::Date64(_value) => "Date64", ScalarValue::Time32Second(_value) => "Time32", ScalarValue::Time32Millisecond(_value) => "Time32", ScalarValue::Time64Microsecond(_value) => "Time64", ScalarValue::Time64Nanosecond(_value) => "Time64", ScalarValue::Null => "Null", ScalarValue::TimestampSecond(..) => "TimestampSecond", ScalarValue::TimestampMillisecond(..) => "TimestampMillisecond", ScalarValue::TimestampMicrosecond(..) => "TimestampMicrosecond", ScalarValue::TimestampNanosecond(..) => "TimestampNanosecond", ScalarValue::IntervalYearMonth(..) => "IntervalYearMonth", ScalarValue::IntervalDayTime(..) => "IntervalDayTime", ScalarValue::IntervalMonthDayNano(..) => "IntervalMonthDayNano", ScalarValue::List(..) => "List", ScalarValue::Struct(..) => "Struct", ScalarValue::FixedSizeBinary(_, _) => "FixedSizeBinary", ScalarValue::Fixedsizelist(..) => "Fixedsizelist", ScalarValue::DurationSecond(..) => "DurationSecond", ScalarValue::DurationMillisecond(..) => "DurationMillisecond", ScalarValue::DurationMicrosecond(..) => "DurationMicrosecond", ScalarValue::DurationNanosecond(..) 
=> "DurationNanosecond", }, Expr::ScalarFunction(ScalarFunction { fun, args: _ }) => match fun { BuiltinScalarFunction::Abs => "Abs", BuiltinScalarFunction::DatePart => "DatePart", _ => { return Err(py_type_err(format!( "Catch all triggered for ScalarFunction in get_type; {fun:?}" ))) } }, Expr::Cast(Cast { expr: _, data_type }) => match data_type { DataType::Null => "NULL", DataType::Boolean => "BOOLEAN", DataType::Int8 | DataType::UInt8 => "TINYINT", DataType::Int16 | DataType::UInt16 => "SMALLINT", DataType::Int32 | DataType::UInt32 => "INTEGER", DataType::Int64 | DataType::UInt64 => "BIGINT", DataType::Float32 => "FLOAT", DataType::Float64 => "DOUBLE", DataType::Timestamp { .. } => "TIMESTAMP", DataType::Date32 | DataType::Date64 => "DATE", DataType::Time32(..) => "TIME32", DataType::Time64(..) => "TIME64", DataType::Duration(..) => "DURATION", DataType::Interval(..) => "INTERVAL", DataType::Binary => "BINARY", DataType::FixedSizeBinary(..) => "FIXEDSIZEBINARY", DataType::LargeBinary => "LARGEBINARY", DataType::Utf8 => "VARCHAR", DataType::LargeUtf8 => "BIGVARCHAR", DataType::List(..) => "LIST", DataType::FixedSizeList(..) => "FIXEDSIZELIST", DataType::LargeList(..) => "LARGELIST", DataType::Struct(..) => "STRUCT", DataType::Union(..) => "UNION", DataType::Dictionary(..) => "DICTIONARY", DataType::Decimal128(..) => "DECIMAL", DataType::Decimal256(..) => "DECIMAL", DataType::Map(..) 
=> "MAP", _ => { return Err(py_type_err(format!( "Catch all triggered for Cast in get_type; {data_type:?}" ))) } }, _ => { return Err(py_type_err(format!( "Catch all triggered in get_type; {:?}", &self.expr ))) } })) } /// Gets the precision/scale represented by the Expression's decimal datatype #[pyo3(name = "getPrecisionScale")] pub fn get_precision_scale(&self) -> PyResult<(u8, i8)> { Ok(match &self.expr { Expr::Cast(Cast { expr: _, data_type }) => match data_type { DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => { (*precision, *scale) } _ => { return Err(py_type_err(format!( "Catch all triggered for Cast in get_precision_scale; {data_type:?}" ))) } }, _ => { return Err(py_type_err(format!( "Catch all triggered in get_precision_scale; {:?}", &self.expr ))) } }) } #[pyo3(name = "getFilterExpr")] pub fn get_filter_expr(&self) -> PyResult> { // TODO refactor to avoid duplication match &self.expr { Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { Expr::AggregateFunction(AggregateFunction { filter, .. }) | Expr::AggregateUDF(AggregateUDF { filter, .. }) => match filter { Some(filter) => { Ok(Some(PyExpr::from(*filter.clone(), self.input_plan.clone()))) } None => Ok(None), }, _ => Err(py_type_err( "getFilterExpr() - Non-aggregate expression encountered", )), }, Expr::AggregateFunction(AggregateFunction { filter, .. }) | Expr::AggregateUDF(AggregateUDF { filter, .. 
}) => match filter { Some(filter) => Ok(Some(PyExpr::from(*filter.clone(), self.input_plan.clone()))), None => Ok(None), }, _ => Err(py_type_err( "getFilterExpr() - Non-aggregate expression encountered", )), } } #[pyo3(name = "getFloat32Value")] pub fn float_32_value(&self) -> PyResult> { extract_scalar_value!(self, Float32) } #[pyo3(name = "getFloat64Value")] pub fn float_64_value(&self) -> PyResult> { extract_scalar_value!(self, Float64) } #[pyo3(name = "getDecimal128Value")] pub fn decimal_128_value(&mut self) -> PyResult<(Option, u8, i8)> { match self.get_scalar_value()? { ScalarValue::Decimal128(value, precision, scale) => Ok((*value, *precision, *scale)), other => Err(unexpected_literal_value(other)), } } #[pyo3(name = "getInt8Value")] pub fn int_8_value(&self) -> PyResult> { extract_scalar_value!(self, Int8) } #[pyo3(name = "getInt16Value")] pub fn int_16_value(&self) -> PyResult> { extract_scalar_value!(self, Int16) } #[pyo3(name = "getInt32Value")] pub fn int_32_value(&self) -> PyResult> { extract_scalar_value!(self, Int32) } #[pyo3(name = "getInt64Value")] pub fn int_64_value(&self) -> PyResult> { extract_scalar_value!(self, Int64) } #[pyo3(name = "getUInt8Value")] pub fn uint_8_value(&self) -> PyResult> { extract_scalar_value!(self, UInt8) } #[pyo3(name = "getUInt16Value")] pub fn uint_16_value(&self) -> PyResult> { extract_scalar_value!(self, UInt16) } #[pyo3(name = "getUInt32Value")] pub fn uint_32_value(&self) -> PyResult> { extract_scalar_value!(self, UInt32) } #[pyo3(name = "getUInt64Value")] pub fn uint_64_value(&self) -> PyResult> { extract_scalar_value!(self, UInt64) } #[pyo3(name = "getDate32Value")] pub fn date_32_value(&self) -> PyResult> { extract_scalar_value!(self, Date32) } #[pyo3(name = "getDate64Value")] pub fn date_64_value(&self) -> PyResult> { extract_scalar_value!(self, Date64) } #[pyo3(name = "getTime64Value")] pub fn time_64_value(&self) -> PyResult> { extract_scalar_value!(self, Time64Nanosecond) } #[pyo3(name = 
"getTimestampValue")]
    pub fn timestamp_value(&mut self) -> PyResult<(Option, Option)> {
        // All timestamp resolutions are returned the same way: the raw value
        // plus the time zone rendered as a string, when one is present.
        match self.get_scalar_value()? {
            ScalarValue::TimestampNanosecond(iv, tz)
            | ScalarValue::TimestampMicrosecond(iv, tz)
            | ScalarValue::TimestampMillisecond(iv, tz)
            | ScalarValue::TimestampSecond(iv, tz) => match tz {
                Some(time_zone) => Ok((*iv, Some(time_zone.to_string()))),
                None => Ok((*iv, None)),
            },
            other => Err(unexpected_literal_value(other)),
        }
    }

    #[pyo3(name = "getBoolValue")]
    pub fn bool_value(&self) -> PyResult> {
        extract_scalar_value!(self, Boolean)
    }

    /// Returns the string held by a `Utf8` literal; any other literal is an error.
    #[pyo3(name = "getStringValue")]
    pub fn string_value(&self) -> PyResult> {
        match self.get_scalar_value()? {
            ScalarValue::Utf8(value) => Ok(value.clone()),
            other => Err(unexpected_literal_value(other)),
        }
    }

    /// Splits an `IntervalDayTime` literal into `(days, milliseconds)`.
    ///
    /// Days are read from the upper 32 bits of the 64-bit value and
    /// milliseconds from the lower 32 bits.
    #[pyo3(name = "getIntervalDayTimeValue")]
    pub fn interval_day_time_value(&self) -> PyResult> {
        match self.get_scalar_value()? {
            ScalarValue::IntervalDayTime(Some(iv)) => {
                let interval = *iv as u64;
                let days = (interval >> 32) as i32;
                let ms = interval as i32;
                Ok(Some((days, ms)))
            }
            ScalarValue::IntervalDayTime(None) => Ok(None),
            other => Err(unexpected_literal_value(other)),
        }
    }

    /// Splits an `IntervalMonthDayNano` literal into `(months, days, nanoseconds)`.
    ///
    /// The 128-bit value is laid out (per Arrow's month-day-nano interval encoding)
    /// as months in the top 32 bits, days in the next 32 bits and nanoseconds in
    /// the low 64 bits.
    #[pyo3(name = "getIntervalMonthDayNanoValue")]
    pub fn interval_month_day_nano_value(&self) -> PyResult> {
        match self.get_scalar_value()? {
            ScalarValue::IntervalMonthDayNano(Some(iv)) => {
                let interval = *iv as u128;
                // Months live in bits 96..128; shifting by 32 (as before) would
                // read the upper half of the nanosecond field instead.
                let months = (interval >> 96) as i32;
                let days = (interval >> 64) as i32;
                let ns = interval as i64;
                Ok(Some((months, days, ns)))
            }
            ScalarValue::IntervalMonthDayNano(None) => Ok(None),
            other => Err(unexpected_literal_value(other)),
        }
    }

    #[pyo3(name = "isNegated")]
    pub fn is_negated(&self) -> PyResult {
        match &self.expr {
            Expr::Between(Between { negated, .. })
            | Expr::Exists(Exists { negated, .. })
            | Expr::InList(InList { negated, .. })
            | Expr::InSubquery(InSubquery { negated, ..
}) => Ok(*negated),
            _ => Err(py_type_err(format!(
                "unknown Expr type {:?} encountered",
                &self.expr
            ))),
        }
    }

    /// Returns whether this aggregate expression was declared DISTINCT.
    ///
    /// A single level of aliasing is looked through. Built-in aggregate functions
    /// report their `distinct` flag; aggregate UDFs always report `false`.
    /// Any other expression is a type error.
    #[pyo3(name = "isDistinctAgg")]
    pub fn is_distinct_aggregation(&self) -> PyResult {
        // Unwrap at most one alias so aliased and bare aggregates share one match.
        let aggregate = match &self.expr {
            Expr::Alias(Alias { expr, .. }) => expr.as_ref(),
            other => other,
        };
        match aggregate {
            Expr::AggregateFunction(funct) => Ok(funct.distinct),
            Expr::AggregateUDF { .. } => Ok(false),
            _ => Err(py_type_err(
                "isDistinctAgg() - Non-aggregate expression encountered",
            )),
        }
    }

    /// Returns if a sort expressions is an ascending sort
    #[pyo3(name = "isSortAscending")]
    pub fn is_sort_ascending(&self) -> PyResult {
        match &self.expr {
            Expr::Sort(Sort { asc, .. }) => Ok(*asc),
            _ => Err(py_type_err(format!(
                "Provided Expr {:?} is not a sort type",
                &self.expr
            ))),
        }
    }

    /// Returns if nulls should be placed first in a sort expression
    #[pyo3(name = "isSortNullsFirst")]
    pub fn is_sort_nulls_first(&self) -> PyResult {
        match &self.expr {
            Expr::Sort(Sort { nulls_first, .. }) => Ok(*nulls_first),
            _ => Err(py_type_err(format!(
                "Provided Expr {:?} is not a sort type",
                &self.expr
            ))),
        }
    }

    /// Returns the escape char for like/ilike/similar to expr variants
    #[pyo3(name = "getEscapeChar")]
    pub fn get_escape_char(&self) -> PyResult> {
        match &self.expr {
            Expr::Like(Like { escape_char, .. })
            | Expr::SimilarTo(Like { escape_char, ..
}) => { Ok(*escape_char) } _ => Err(py_type_err(format!( "Provided Expr {:?} not one of Like/ILike/SimilarTo", &self.expr ))), } } } impl PyExpr { /// Get the scalar value represented by this literal expression, returning an error /// if this is not a literal expression fn get_scalar_value(&self) -> Result<&ScalarValue> { match &self.expr { Expr::Literal(v) => Ok(v), _ => Err(DaskPlannerError::Internal( "get_scalar_value() called on non-literal expression".to_string(), )), } } } fn unexpected_literal_value(value: &ScalarValue) -> PyErr { DaskPlannerError::Internal(format!("getValue() - Unexpected value: {value}")).into() } fn get_expr_name(expr: &Expr) -> Result { match expr { Expr::Alias(Alias { expr, .. }) => get_expr_name(expr), Expr::Wildcard => { // 'Wildcard' means any and all columns. We get the first valid column name here Ok("*".to_owned()) } _ => Ok(expr.canonical_name()), } } /// Create a [DFField] representing an [Expr], given an input [LogicalPlan] to resolve against pub fn expr_to_field(expr: &Expr, input_plan: &LogicalPlan) -> Result { match expr { Expr::Sort(Sort { expr, .. }) => { // DataFusion does not support create_name for sort expressions (since they never // appear in projections) so we just delegate to the contained expression instead expr_to_field(expr, input_plan) } Expr::Wildcard => { // Any column will do. 
We use the first column to keep things consistent Ok(input_plan.schema().field(0).clone()) } Expr::InSubquery(insubquery) => expr_to_field(&insubquery.expr, input_plan), _ => { let fields = exprlist_to_fields(&[expr.clone()], input_plan).map_err(DaskPlannerError::from)?; Ok(fields[0].clone()) } } } #[cfg(test)] mod test { use datafusion_python::{ datafusion_common::{Column, ScalarValue}, datafusion_expr::Expr, }; use crate::{error::Result, expression::PyExpr}; #[test] fn get_value_u32() -> Result<()> { test_get_value(ScalarValue::UInt32(None))?; test_get_value(ScalarValue::UInt32(Some(123))) } #[test] fn get_value_utf8() -> Result<()> { test_get_value(ScalarValue::Utf8(None))?; test_get_value(ScalarValue::Utf8(Some("hello".to_string()))) } #[test] fn get_value_non_literal() -> Result<()> { let expr = PyExpr::from(Expr::Column(Column::from_qualified_name("a.b")), None); let error = expr .get_scalar_value() .expect_err("cannot get scalar value from column"); assert_eq!( "Internal(\"get_scalar_value() called on non-literal expression\")", &format!("{:?}", error) ); Ok(()) } fn test_get_value(value: ScalarValue) -> Result<()> { let expr = PyExpr::from(Expr::Literal(value.clone()), None); assert_eq!(&value, expr.get_scalar_value()?); Ok(()) } } ================================================ FILE: src/lib.rs ================================================ use log::debug; use pyo3::prelude::*; mod dialect; mod error; mod expression; mod parser; mod sql; /// Low-level DataFusion internal package. /// /// The higher-level public API is defined in pure python files under the /// dask_planner directory. 
#[pymodule] fn _datafusion_lib(py: Python, m: &PyModule) -> PyResult<()> { // Initialize the global Python logger instance pyo3_log::init(); // Register the python classes m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; // Exceptions m.add( "DFParsingException", py.get_type::(), )?; m.add( "DFOptimizationException", py.get_type::(), )?; debug!("dask_sql native library loaded"); Ok(()) } ================================================ FILE: src/parser.rs ================================================ //! SQL Parser //! //! Declares a SQL parser based on sqlparser that handles custom formats that we need. use std::collections::VecDeque; use datafusion_python::datafusion_sql::sqlparser::{ ast::{Expr, Ident, SelectItem, Statement as SQLStatement, UnaryOperator, Value}, dialect::{keywords::Keyword, Dialect}, parser::{Parser, ParserError}, tokenizer::{Token, TokenWithLocation, Tokenizer}, }; use pyo3::prelude::*; use crate::{ dialect::DaskDialect, sql::{exceptions::py_type_err, parser_utils::DaskParserUtils, types::SqlTypeName}, }; macro_rules! 
parser_err { ($MSG:expr) => { Err(ParserError::ParserError($MSG.to_string())) }; } #[derive(Debug, Clone, PartialEq, Eq)] pub enum CustomExpr { Map(Vec), Multiset(Vec), Nested(Vec<(String, PySqlArg)>), } #[pyclass(name = "SqlArg", module = "dask_sql")] #[derive(Debug, Clone, PartialEq, Eq)] pub struct PySqlArg { expr: Option, custom: Option, } impl PySqlArg { pub fn new(expr: Option, custom: Option) -> Self { Self { expr, custom } } fn expected(&self, expected: &str) -> PyResult { Err(match &self.custom { Some(custom_expr) => { py_type_err(format!("Expected {expected}, found: {custom_expr:?}")) } None => match &self.expr { Some(expr) => py_type_err(format!("Expected {expected}, found: {expr:?}")), None => py_type_err("PySqlArg must be either a standard or custom AST expression"), }, }) } } #[pymethods] impl PySqlArg { #[pyo3(name = "isCollection")] pub fn is_collection(&self) -> PyResult { Ok(match &self.custom { Some(custom_expr) => !matches!(custom_expr, CustomExpr::Nested(_)), None => match &self.expr { Some(expr) => matches!(expr, Expr::Array(_)), None => return self.expected(""), }, }) } #[pyo3(name = "isKwargs")] pub fn is_kwargs(&self) -> PyResult { Ok(matches!(&self.custom, Some(CustomExpr::Nested(_)))) } #[pyo3(name = "getOperandList")] pub fn get_operand_list(&self) -> PyResult> { Ok(match &self.custom { Some(custom_expr) => match custom_expr { CustomExpr::Map(exprs) | CustomExpr::Multiset(exprs) => exprs .iter() .map(|e| PySqlArg::new(Some(e.clone()), None)) .collect(), _ => vec![], }, None => match &self.expr { Some(expr) => match expr { Expr::Array(array) => array .elem .iter() .map(|e| PySqlArg::new(Some(e.clone()), None)) .collect(), _ => vec![], }, None => return self.expected(""), }, }) } #[pyo3(name = "getKwargs")] pub fn get_kwargs(&self) -> PyResult> { Ok(match &self.custom { Some(CustomExpr::Nested(kwargs)) => kwargs.clone(), _ => vec![], }) } #[pyo3(name = "getSqlType")] pub fn get_sql_type(&self) -> PyResult { Ok(match &self.custom { 
Some(custom_expr) => match custom_expr { CustomExpr::Map(_) => SqlTypeName::MAP, CustomExpr::Multiset(_) => SqlTypeName::MULTISET, _ => return self.expected("Map or multiset"), }, None => match &self.expr { Some(Expr::Array(_)) => SqlTypeName::ARRAY, Some(Expr::Identifier(Ident { .. })) => SqlTypeName::VARCHAR, Some(Expr::Value(scalar)) => match scalar { Value::Boolean(_) => SqlTypeName::BOOLEAN, Value::Number(_, false) => SqlTypeName::BIGINT, Value::SingleQuotedString(_) => SqlTypeName::VARCHAR, _ => return self.expected("Boolean, integer, float, or single-quoted string"), }, Some(Expr::UnaryOp { op: UnaryOperator::Minus, expr, }) => match &**expr { Expr::Value(Value::Number(_, false)) => SqlTypeName::BIGINT, _ => return self.expected("Integer or float"), }, Some(_) => return self.expected("Array, identifier, or scalar"), None => return self.expected(""), }, }) } #[pyo3(name = "getSqlValue")] pub fn get_sql_value(&self) -> PyResult { Ok(match &self.custom { None => match &self.expr { Some(Expr::Identifier(Ident { value, .. })) => value.to_string(), Some(Expr::Value(scalar)) => match scalar { Value::Boolean(true) => "1".to_string(), Value::Boolean(false) => "".to_string(), Value::SingleQuotedString(string) => string.to_string(), Value::Number(value, false) => value.to_string(), _ => return self.expected("Boolean, integer, float, or single-quoted string"), }, Some(Expr::UnaryOp { op: UnaryOperator::Minus, expr, }) => match &**expr { Expr::Value(Value::Number(value, false)) => format!("-{value}"), _ => return self.expected("Integer or float"), }, _ => return self.expected("Array, identifier, or scalar"), }, _ => return self.expected("Standard sqlparser AST expression"), }) } } /// Dask-SQL extension DDL for `CREATE MODEL` #[derive(Debug, Clone, PartialEq, Eq)] pub struct CreateModel { /// schema and model name, i.e. 
'schema_name.model_name' pub schema_name: Option, pub model_name: String, /// input query pub select: DaskStatement, /// whether or not IF NOT EXISTS was specified pub if_not_exists: bool, /// whether or not OR REPLACE was specified pub or_replace: bool, /// kwargs specified in WITH pub with_options: Vec<(String, PySqlArg)>, } /// Dask-SQL extension DDL for `CREATE EXPERIMENT` #[derive(Debug, Clone, PartialEq, Eq)] pub struct CreateExperiment { /// schema and experiment name, i.e. 'schema_name.experiment_name' pub schema_name: Option, pub experiment_name: String, /// input query pub select: DaskStatement, /// whether or not IF NOT EXISTS was specified pub if_not_exists: bool, /// whether or not OR REPLACE was specified pub or_replace: bool, /// kwargs specified in WITH pub with_options: Vec<(String, PySqlArg)>, } /// Dask-SQL extension DDL for `PREDICT` #[derive(Debug, Clone, PartialEq, Eq)] pub struct PredictModel { /// schema and model name, i.e. 'schema_name.model_name' pub schema_name: Option, pub model_name: String, /// input query pub select: DaskStatement, } /// Dask-SQL extension DDL for `CREATE SCHEMA` #[derive(Debug, Clone, PartialEq, Eq)] pub struct CreateCatalogSchema { /// schema name pub schema_name: String, /// whether or not IF NOT EXISTS was specified pub if_not_exists: bool, /// whether or not OR REPLACE was specified pub or_replace: bool, } /// Dask-SQL extension DDL for `CREATE TABLE ... WITH` #[derive(Debug, Clone, PartialEq, Eq)] pub struct CreateTable { /// schema and table name, i.e. 'schema_name.table_name' pub schema_name: Option, pub table_name: String, /// whether or not IF NOT EXISTS was specified pub if_not_exists: bool, /// whether or not OR REPLACE was specified pub or_replace: bool, /// kwargs specified in WITH pub with_options: Vec<(String, PySqlArg)>, } /// Dask-SQL extension DDL for `DROP MODEL` #[derive(Debug, Clone, PartialEq, Eq)] pub struct DropModel { /// schema and model name, i.e. 
/// 'schema_name.model_name'
    pub schema_name: Option,
    pub model_name: String,
    /// whether or not IF EXISTS was specified
    pub if_exists: bool,
}

/// Dask-SQL extension DDL for `EXPORT MODEL`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExportModel {
    /// schema and model name, i.e. 'schema_name.model_name'
    pub schema_name: Option,
    pub model_name: String,
    /// kwargs specified in WITH
    pub with_options: Vec<(String, PySqlArg)>,
}

/// Dask-SQL extension DDL for `DESCRIBE MODEL`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DescribeModel {
    /// schema and model name, i.e. 'schema_name.model_name'
    pub schema_name: Option,
    pub model_name: String,
}

/// Dask-SQL extension DDL for `SHOW SCHEMAS`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShowSchemas {
    /// optional catalog name
    pub catalog_name: Option,
    /// optional LIKE identifier
    pub like: Option,
}

/// Dask-SQL extension DDL for `SHOW TABLES FROM`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShowTables {
    /// catalog and schema name, i.e. 'catalog_name.schema_name'
    pub catalog_name: Option,
    pub schema_name: Option,
}

/// Dask-SQL extension DDL for `SHOW COLUMNS FROM`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShowColumns {
    /// schema and table name, i.e. 'schema_name.table_name'
    pub schema_name: Option,
    pub table_name: String,
}

/// Dask-SQL extension DDL for `SHOW MODELS`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShowModels {
    /// optional schema name
    pub schema_name: Option,
}

/// Dask-SQL extension DDL for `DROP SCHEMA`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DropSchema {
    /// schema name
    pub schema_name: String,
    /// whether or not IF EXISTS was specified
    pub if_exists: bool,
}

/// Dask-SQL extension DDL for `USE SCHEMA`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UseSchema {
    /// schema name
    pub schema_name: String,
}

/// Dask-SQL extension DDL for `ANALYZE TABLE`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AnalyzeTable {
    /// schema and table name, i.e.
/// 'schema_name.table_name'
    pub schema_name: Option,
    pub table_name: String,
    /// columns to analyze in specified table
    pub columns: Vec,
}

/// Dask-SQL extension DDL for `ALTER TABLE`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AlterTable {
    /// current table name
    pub old_table_name: String,
    /// table name after the rename
    pub new_table_name: String,
    /// optional schema name
    pub schema_name: Option,
    /// presumably whether IF EXISTS was specified, as on the DROP statements;
    /// confirm against the ALTER parser
    pub if_exists: bool,
}

/// Dask-SQL extension DDL for `ALTER SCHEMA`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AlterSchema {
    /// current schema name
    pub old_schema_name: String,
    /// schema name after the rename
    pub new_schema_name: String,
}

/// Dask-SQL Statement representations.
///
/// Tokens parsed by `DaskParser` are converted into these values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DaskStatement {
    /// ANSI SQL AST node
    Statement(Box),
    /// Extension: `CREATE MODEL`
    CreateModel(Box),
    /// Extension: `CREATE EXPERIMENT`
    CreateExperiment(Box),
    /// Extension: `CREATE SCHEMA`
    CreateCatalogSchema(Box),
    /// Extension: `CREATE TABLE`
    CreateTable(Box),
    /// Extension: `DROP MODEL`
    DropModel(Box),
    /// Extension: `EXPORT MODEL`
    ExportModel(Box),
    /// Extension: `DESCRIBE MODEL`
    DescribeModel(Box),
    /// Extension: `PREDICT`
    PredictModel(Box),
    /// Extension: `SHOW SCHEMAS`
    ShowSchemas(Box),
    /// Extension: `SHOW TABLES FROM`
    ShowTables(Box),
    /// Extension: `SHOW COLUMNS FROM`
    ShowColumns(Box),
    /// Extension: `SHOW MODELS`
    ShowModels(Box),
    /// Extension: `DROP SCHEMA`
    DropSchema(Box),
    /// Extension: `USE SCHEMA`
    UseSchema(Box),
    /// Extension: `ANALYZE TABLE`
    AnalyzeTable(Box),
    /// Extension: `ALTER TABLE`
    AlterTable(Box),
    /// Extension: `ALTER SCHEMA`
    AlterSchema(Box),
}

/// SQL Parser
pub struct DaskParser<'a> {
    /// Underlying sqlparser `Parser` that holds the token stream
    parser: Parser<'a>,
}

impl<'a> DaskParser<'a> {
    #[allow(dead_code)]
    /// Create a parser for the given SQL text using the `DaskDialect`
    pub fn new(sql: &str) -> Result {
        let dialect = &DaskDialect {};
        DaskParser::new_with_dialect(sql, dialect)
    }

    /// Create a parser for the given SQL text, tokenizing it with the supplied dialect
    pub fn new_with_dialect(sql: &str, dialect: &'a dyn Dialect) -> Result {
        let mut tokenizer = Tokenizer::new(dialect, sql);
        let tokens =
tokenizer.tokenize()?; Ok(DaskParser { parser: Parser::new(dialect).with_tokens(tokens), }) } #[allow(dead_code)] /// Parse a SQL statement and produce a set of statements with dialect pub fn parse_sql(sql: &str) -> Result, ParserError> { let dialect = &DaskDialect {}; DaskParser::parse_sql_with_dialect(sql, dialect) } /// Parse a SQL statement and produce a set of statements pub fn parse_sql_with_dialect( sql: &str, dialect: &dyn Dialect, ) -> Result, ParserError> { let mut parser = DaskParser::new_with_dialect(sql, dialect)?; let mut stmts = VecDeque::new(); let mut expecting_statement_delimiter = false; loop { // ignore empty statements (between successive statement delimiters) while parser.parser.consume_token(&Token::SemiColon) { expecting_statement_delimiter = false; } if parser.parser.peek_token() == Token::EOF { break; } if expecting_statement_delimiter { return parser.expected("end of statement", parser.parser.peek_token()); } let statement = parser.parse_statement()?; stmts.push_back(statement); expecting_statement_delimiter = true; } Ok(stmts) } /// Report unexpected token fn expected(&self, expected: &str, found: TokenWithLocation) -> Result { parser_err!(format!( "Expected {}, found: {} at line {} column {}", expected, found.token, found.location.line, found.location.column )) } /// Parse a new expression pub fn parse_statement(&mut self) -> Result { match self.parser.peek_token().token { Token::Word(w) => { match w.keyword { Keyword::CREATE => { // move one token forward self.parser.next_token(); // use custom parsing self.parse_create() } Keyword::DROP => { // move one token forward self.parser.next_token(); // use custom parsing self.parse_drop() } Keyword::SELECT => { // Check for PREDICT token in statement let mut cnt = 1; loop { match self.parser.next_token().token { Token::Word(w) => { match w.value.to_lowercase().as_str() { "predict" => { return self.parse_predict_model(); } _ => { // Keep looking for PREDICT cnt += 1; continue; } } } 
Token::EOF => { break; } _ => { // Keep looking for PREDICT cnt += 1; continue; } } } // Reset the parser back to where we started for _ in 0..cnt { self.parser.prev_token(); } // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_statement()?, ))) } Keyword::SHOW => { // move one token forward self.parser.next_token(); // use custom parsing self.parse_show() } Keyword::DESCRIBE => { // move one token forwrd self.parser.next_token(); // use custom parsing self.parse_describe() } Keyword::USE => { // move one token forwrd self.parser.next_token(); // use custom parsing self.parse_use() } Keyword::ANALYZE => { // move one token foward self.parser.next_token(); self.parse_analyze() } Keyword::ALTER => { // move one token forward self.parser.next_token(); self.parse_alter() } _ => { match w.value.to_lowercase().as_str() { "export" => { // move one token forwrd self.parser.next_token(); // use custom parsing self.parse_export_model() } _ => { // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_statement()?, ))) } } } } } _ => { // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_statement()?, ))) } } } /// Parse a SQL CREATE statement pub fn parse_create(&mut self) -> Result { let or_replace = self.parser.parse_keywords(&[Keyword::OR, Keyword::REPLACE]); match self.parser.peek_token().token { Token::Word(w) => { match w.value.to_lowercase().as_str() { "model" => { // move one token forward self.parser.next_token(); let if_not_exists = self.parser.parse_keywords(&[ Keyword::IF, Keyword::NOT, Keyword::EXISTS, ]); // use custom parsing self.parse_create_model(if_not_exists, or_replace) } "experiment" => { // move one token forward self.parser.next_token(); let if_not_exists = self.parser.parse_keywords(&[ Keyword::IF, Keyword::NOT, Keyword::EXISTS, ]); // use custom parsing self.parse_create_experiment(if_not_exists, or_replace) } "schema" => { // move one token forward 
self.parser.next_token(); let if_not_exists = self.parser.parse_keywords(&[ Keyword::IF, Keyword::NOT, Keyword::EXISTS, ]); // use custom parsing self.parse_create_schema(if_not_exists, or_replace) } "table" => { // move one token forward self.parser.next_token(); // use custom parsing self.parse_create_table(true, or_replace) } "view" => { // move one token forward self.parser.next_token(); // use custom parsing self.parse_create_table(false, or_replace) } _ => { if or_replace { // Go back two tokens if OR REPLACE was consumed self.parser.prev_token(); self.parser.prev_token(); } // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_create()?, ))) } } } _ => { if or_replace { // Go back two tokens if OR REPLACE was consumed self.parser.prev_token(); self.parser.prev_token(); } // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_create()?, ))) } } } /// Parse a SQL DROP statement pub fn parse_drop(&mut self) -> Result { match self.parser.peek_token().token { Token::Word(w) => { match w.value.to_lowercase().as_str() { "model" => { // move one token forward self.parser.next_token(); // use custom parsing self.parse_drop_model() } "schema" => { // move one token forward self.parser.next_token(); // use custom parsing let if_exists = self.parser.parse_keywords(&[Keyword::IF, Keyword::EXISTS]); let schema_name = self.parser.parse_identifier()?; let drop_schema = DropSchema { schema_name: schema_name.value, if_exists, }; Ok(DaskStatement::DropSchema(Box::new(drop_schema))) } _ => { // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_drop()?, ))) } } } _ => { // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_drop()?, ))) } } } /// Parse a SQL SHOW statement pub fn parse_show(&mut self) -> Result { match self.parser.peek_token().token { Token::Word(w) => { match w.value.to_lowercase().as_str() { "schemas" => { // move one token forward 
self.parser.next_token(); // use custom parsing self.parse_show_schemas() } "tables" => { // move one token forward self.parser.next_token(); // If non ansi ... `FROM {schema_name}` is present custom parse // otherwise use sqlparser-rs match self.parser.peek_token().token { Token::Word(w) => { match w.value.to_lowercase().as_str() { "from" => { // move one token forward self.parser.next_token(); // use custom parsing self.parse_show_tables() } _ => { self.parser.prev_token(); // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_show()?, ))) } } } _ => self.parse_show_tables(), } } "columns" => { self.parser.next_token(); // use custom parsing self.parse_show_columns() } "models" => { self.parser.next_token(); // use custom parsing self.parse_show_models() } _ => { // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_show()?, ))) } } } _ => { // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_show()?, ))) } } } /// Parse a SQL DESCRIBE statement pub fn parse_describe(&mut self) -> Result { match self.parser.peek_token().token { Token::Word(w) => { match w.value.to_lowercase().as_str() { "model" => { self.parser.next_token(); // use custom parsing self.parse_describe_model() } _ => { // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_show()?, ))) } } } _ => { // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_show()?, ))) } } } /// Parse a SQL USE SCHEMA statement pub fn parse_use(&mut self) -> Result { match self.parser.peek_token().token { Token::Word(w) => { match w.value.to_lowercase().as_str() { "schema" => { // move one token forward self.parser.next_token(); // use custom parsing let schema_name = self.parser.parse_identifier()?; let use_schema = UseSchema { schema_name: schema_name.value, }; Ok(DaskStatement::UseSchema(Box::new(use_schema))) } _ => Ok(DaskStatement::Statement(Box::from( 
self.parser.parse_show()?, ))), } } _ => Ok(DaskStatement::Statement(Box::from( self.parser.parse_show()?, ))), } } /// Parse a SQL ANALYZE statement pub fn parse_analyze(&mut self) -> Result { match self.parser.peek_token().token { Token::Word(w) => { match w.value.to_lowercase().as_str() { "table" => { // move one token forward self.parser.next_token(); // use custom parsing self.parse_analyze_table() } _ => { // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_analyze()?, ))) } } } _ => { // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_analyze()?, ))) } } } /// Parse a SQL ALTER statement pub fn parse_alter(&mut self) -> Result { match self.parser.peek_token().token { Token::Word(w) => { match w.keyword { Keyword::TABLE => { self.parser.next_token(); self.parse_alter_table() } Keyword::SCHEMA => { self.parser.next_token(); self.parse_alter_schema() } _ => { // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_alter()?, ))) } } } _ => { // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser.parse_alter()?, ))) } } } /// Parse a SQL PREDICT statement pub fn parse_predict_model(&mut self) -> Result { // PREDICT( // MODEL model_name, // SQLStatement // ) self.parser.expect_token(&Token::LParen)?; let is_model = match self.parser.next_token().token { Token::Word(w) => matches!(w.value.to_lowercase().as_str(), "model"), _ => false, }; if !is_model { return Err(ParserError::ParserError( "parse_predict_model: Expected `MODEL`".to_string(), )); } let (schema_name, model_name) = DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?; self.parser.expect_token(&Token::Comma)?; // Limit our input to ANALYZE, DESCRIBE, SELECT, SHOW statements // TODO: find a more sophisticated way to allow any statement that would return a table self.parser.expect_one_of_keywords(&[ Keyword::SELECT, Keyword::DESCRIBE, Keyword::SHOW, Keyword::ANALYZE, ])?; 
self.parser.prev_token(); let select = self.parse_statement()?; self.parser.expect_token(&Token::RParen)?; let predict = PredictModel { schema_name, model_name, select, }; Ok(DaskStatement::PredictModel(Box::new(predict))) } /// Parse Dask-SQL CREATE MODEL statement fn parse_create_model( &mut self, if_not_exists: bool, or_replace: bool, ) -> Result { // Parse schema and model name let (schema_name, model_name) = DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?; // Parse WITH options self.parser.expect_keyword(Keyword::WITH)?; self.parser.expect_token(&Token::LParen)?; let with_options = self.parse_comma_separated(DaskParser::parse_key_value_pair)?; self.parser.expect_token(&Token::RParen)?; // Parse the nested query statement self.parser.expect_keyword(Keyword::AS)?; self.parser.expect_token(&Token::LParen)?; // Limit our input to ANALYZE, DESCRIBE, SELECT, SHOW statements // TODO: find a more sophisticated way to allow any statement that would return a table self.parser.expect_one_of_keywords(&[ Keyword::SELECT, Keyword::DESCRIBE, Keyword::SHOW, Keyword::ANALYZE, ])?; self.parser.prev_token(); let select = self.parse_statement()?; self.parser.expect_token(&Token::RParen)?; let create = CreateModel { schema_name, model_name, select, if_not_exists, or_replace, with_options, }; Ok(DaskStatement::CreateModel(Box::new(create))) } // copied from sqlparser crate and adapted to work with DaskParser fn parse_comma_separated(&mut self, mut f: F) -> Result, ParserError> where F: FnMut(&mut DaskParser<'a>) -> Result, { let mut values = vec![]; loop { values.push(f(self)?); if !self.parser.consume_token(&Token::Comma) { break; } } Ok(values) } fn parse_key_value_pair(&mut self) -> Result<(String, PySqlArg), ParserError> { let key = self.parser.parse_identifier()?; self.parser.expect_token(&Token::Eq)?; match self.parser.next_token().token { Token::LParen => { let key_value_pairs = self.parse_comma_separated(DaskParser::parse_key_value_pair)?; 
self.parser.expect_token(&Token::RParen)?; Ok(( key.value, PySqlArg::new(None, Some(CustomExpr::Nested(key_value_pairs))), )) } Token::Word(w) if w.value.to_lowercase().as_str() == "map" => { // TODO this does not support map or multiset expressions within the map self.parser.expect_token(&Token::LBracket)?; let values = self.parser.parse_comma_separated(Parser::parse_expr)?; self.parser.expect_token(&Token::RBracket)?; Ok(( key.value, PySqlArg::new(None, Some(CustomExpr::Map(values))), )) } Token::Word(w) if w.value.to_lowercase().as_str() == "multiset" => { // TODO this does not support map or multiset expressions within the multiset self.parser.expect_token(&Token::LBracket)?; let values = self.parser.parse_comma_separated(Parser::parse_expr)?; self.parser.expect_token(&Token::RBracket)?; Ok(( key.value, PySqlArg::new(None, Some(CustomExpr::Multiset(values))), )) } _ => { self.parser.prev_token(); Ok(( key.value, PySqlArg::new(Some(self.parser.parse_expr()?), None), )) } } } /// Parse Dask-SQL CREATE EXPERIMENT statement fn parse_create_experiment( &mut self, if_not_exists: bool, or_replace: bool, ) -> Result { // Parse schema and model name let (schema_name, experiment_name) = DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?; // Parse WITH options self.parser.expect_keyword(Keyword::WITH)?; self.parser.expect_token(&Token::LParen)?; let with_options = self.parse_comma_separated(DaskParser::parse_key_value_pair)?; self.parser.expect_token(&Token::RParen)?; // Parse the nested query statement self.parser.expect_keyword(Keyword::AS)?; self.parser.expect_token(&Token::LParen)?; // Limit our input to ANALYZE, DESCRIBE, SELECT, SHOW statements // TODO: find a more sophisticated way to allow any statement that would return a table self.parser.expect_one_of_keywords(&[ Keyword::SELECT, Keyword::DESCRIBE, Keyword::SHOW, Keyword::ANALYZE, ])?; self.parser.prev_token(); let select = self.parse_statement()?; 
self.parser.expect_token(&Token::RParen)?; let create = CreateExperiment { schema_name, experiment_name, select, if_not_exists, or_replace, with_options, }; Ok(DaskStatement::CreateExperiment(Box::new(create))) } /// Parse Dask-SQL CREATE {IF NOT EXISTS | OR REPLACE} SCHEMA ... statement fn parse_create_schema( &mut self, if_not_exists: bool, or_replace: bool, ) -> Result { let schema_name = self.parser.parse_identifier()?.value; let create = CreateCatalogSchema { schema_name, if_not_exists, or_replace, }; Ok(DaskStatement::CreateCatalogSchema(Box::new(create))) } /// Parse Dask-SQL CREATE [OR REPLACE] TABLE ... statement /// /// # Arguments /// /// * `is_table` - Whether the "table" is a "TABLE" or "VIEW", True if "TABLE" and False otherwise. /// * `or_replace` - True if the "TABLE" or "VIEW" should be replaced and False otherwise fn parse_create_table( &mut self, is_table: bool, or_replace: bool, ) -> Result { // parse [IF NOT EXISTS] `table_name` AS|WITH let if_not_exists = self.parser .parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let _table_name = self.parser.parse_identifier(); let after_name_token = self.parser.peek_token().token; match after_name_token { Token::Word(w) => { match w.value.to_lowercase().as_str() { "as" => { self.parser.prev_token(); if if_not_exists { // Go back three tokens if IF NOT EXISTS was consumed, native parser consumes these tokens as well self.parser.prev_token(); self.parser.prev_token(); self.parser.prev_token(); } // True if TABLE and False if VIEW if is_table { Ok(DaskStatement::Statement(Box::from( self.parser .parse_create_table(or_replace, false, None, false)?, ))) } else { self.parser.prev_token(); Ok(DaskStatement::Statement(Box::from( self.parser.parse_create_view(or_replace)?, ))) } } "with" => { // `table_name` has been parsed at this point but is needed, reset consumption self.parser.prev_token(); // Parse schema and table name let (schema_name, table_name) = DaskParserUtils::elements_from_object_name( 
&self.parser.parse_object_name()?, )?; // Parse WITH options self.parser.expect_keyword(Keyword::WITH)?; self.parser.expect_token(&Token::LParen)?; let with_options = self.parse_comma_separated(DaskParser::parse_key_value_pair)?; self.parser.expect_token(&Token::RParen)?; let create = CreateTable { schema_name, table_name, if_not_exists, or_replace, with_options, }; Ok(DaskStatement::CreateTable(Box::new(create))) } _ => self.expected("'as' or 'with'", self.parser.peek_token()), } } _ => { self.parser.prev_token(); if if_not_exists { // Go back three tokens if IF NOT EXISTS was consumed self.parser.prev_token(); self.parser.prev_token(); self.parser.prev_token(); } // use the native parser Ok(DaskStatement::Statement(Box::from( self.parser .parse_create_table(or_replace, false, None, false)?, ))) } } } /// Parse Dask-SQL EXPORT MODEL statement fn parse_export_model(&mut self) -> Result { let is_model = match self.parser.next_token().token { Token::Word(w) => matches!(w.value.to_lowercase().as_str(), "model"), _ => false, }; if !is_model { return Err(ParserError::ParserError( "parse_export_model: Expected `MODEL`".to_string(), )); } // Parse schema and model name let (schema_name, model_name) = DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?; // Parse WITH options self.parser.expect_keyword(Keyword::WITH)?; self.parser.expect_token(&Token::LParen)?; let with_options = self.parse_comma_separated(DaskParser::parse_key_value_pair)?; self.parser.expect_token(&Token::RParen)?; let export = ExportModel { schema_name, model_name, with_options, }; Ok(DaskStatement::ExportModel(Box::new(export))) } /// Parse Dask-SQL DROP MODEL statement fn parse_drop_model(&mut self) -> Result { let if_exists = self.parser.parse_keywords(&[Keyword::IF, Keyword::EXISTS]); // Parse schema and model name let (schema_name, model_name) = DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?; let drop = DropModel { schema_name, model_name, 
if_exists, }; Ok(DaskStatement::DropModel(Box::new(drop))) } /// Parse Dask-SQL DESRIBE MODEL statement fn parse_describe_model(&mut self) -> Result { // Parse schema and model name let (schema_name, model_name) = DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?; let describe = DescribeModel { schema_name, model_name, }; Ok(DaskStatement::DescribeModel(Box::new(describe))) } /// Parse Dask-SQL SHOW SCHEMAS statement fn parse_show_schemas(&mut self) -> Result { // parse optional `FROM` clause let catalog_name = match self.parser.peek_token().token { Token::Word(w) => { match w.keyword { Keyword::FROM => { // move one token forward self.parser.next_token(); // use custom parsing Some(self.parser.parse_identifier()?.value) } _ => None, } } _ => None, }; // parse optional `LIKE` clause let like = match self.parser.peek_token().token { Token::Word(w) => { match w.keyword { Keyword::LIKE => { // move one token forward self.parser.next_token(); // use custom parsing Some(self.parser.parse_identifier()?.value) } _ => None, } } _ => None, }; Ok(DaskStatement::ShowSchemas(Box::new(ShowSchemas { catalog_name, like, }))) } /// Parse Dask-SQL SHOW TABLES [FROM] statement fn parse_show_tables(&mut self) -> Result { if let Ok(obj_name) = &self.parser.parse_object_name() { let (catalog_name, schema_name) = DaskParserUtils::elements_from_object_name(obj_name)?; return Ok(DaskStatement::ShowTables(Box::new(ShowTables { catalog_name, schema_name: Some(schema_name), }))); } Ok(DaskStatement::ShowTables(Box::new(ShowTables { catalog_name: None, schema_name: None, }))) } /// Parse Dask-SQL SHOW COLUMNS FROM fn parse_show_columns(&mut self) -> Result { self.parser.expect_keyword(Keyword::FROM)?; let (schema_name, table_name) = DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?; Ok(DaskStatement::ShowColumns(Box::new(ShowColumns { schema_name, table_name, }))) } /// Parse Dask-SQL SHOW MODEL [FROM ] fn parse_show_models(&mut self) 
-> Result { let mut schema_name: Option = None; if !self.parser.consume_token(&Token::EOF) { self.parser.expect_keyword(Keyword::FROM)?; schema_name = Some(self.parser.parse_identifier()?.value); } Ok(DaskStatement::ShowModels(Box::new(ShowModels { schema_name, }))) } /// Parse Dask-SQL ANALYZE TABLE
fn parse_analyze_table(&mut self) -> Result { let obj_name = self.parser.parse_object_name()?; self.parser .expect_keywords(&[Keyword::COMPUTE, Keyword::STATISTICS, Keyword::FOR])?; let (schema_name, table_name) = DaskParserUtils::elements_from_object_name(&obj_name)?; let columns = match self .parser .parse_keywords(&[Keyword::ALL, Keyword::COLUMNS]) { true => vec![], false => { self.parser.expect_keyword(Keyword::COLUMNS)?; let mut values = vec![]; for select in self.parser.parse_projection()? { match select { SelectItem::UnnamedExpr(expr) => match expr { Expr::Identifier(ident) => values.push(ident.value), unexpected => { return parser_err!(format!( "Expected Identifier, found: {unexpected}" )) } }, unexpected => { return parser_err!(format!( "Expected UnnamedExpr, found: {unexpected}" )) } } } values } }; Ok(DaskStatement::AnalyzeTable(Box::new(AnalyzeTable { schema_name, table_name, columns, }))) } fn parse_alter_table(&mut self) -> Result { let if_exists = self.parser.parse_keywords(&[Keyword::IF, Keyword::EXISTS]); // parse fully qualified old table name let (schema_name, old_table_name) = DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?; self.parser .expect_keywords(&[Keyword::RENAME, Keyword::TO])?; // parse new table name let new_table_name = self.parser.parse_identifier()?.value; Ok(DaskStatement::AlterTable(Box::new(AlterTable { old_table_name, new_table_name, schema_name, if_exists, }))) } fn parse_alter_schema(&mut self) -> Result { // parse old schema name let old_schema_name = self.parser.parse_identifier()?.value; self.parser .expect_keywords(&[Keyword::RENAME, Keyword::TO])?; // parse new schema name let new_schema_name = self.parser.parse_identifier()?.value; Ok(DaskStatement::AlterSchema(Box::new(AlterSchema { old_schema_name, new_schema_name, }))) } } #[cfg(test)] mod test { use crate::parser::{DaskParser, DaskStatement}; #[test] fn timestampadd() { let sql = "SELECT TIMESTAMPADD(YEAR, 2, d) FROM t"; let statements 
= DaskParser::parse_sql(sql).unwrap(); assert_eq!(1, statements.len()); let actual = format!("{:?}", statements[0]); let expected = "Statement(Query(Query { with: None, body: Select(Select { distinct: None, top: None, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: \"timestampadd\", quote_style: None }]), args: [Unnamed(Expr(Value(SingleQuotedString(\"YEAR\")))), Unnamed(Expr(Value(Number(\"2\", false)))), Unnamed(Expr(Identifier(Ident { value: \"d\", quote_style: None })))], over: None, distinct: false, special: false, order_by: [] }))], into: None, from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: \"t\", quote_style: None }]), alias: None, args: None, with_hints: [], version: None, partitions: [] }, joins: [] }], lateral_views: [], selection: None, group_by: Expressions([]), cluster_by: [], distribute_by: [], sort_by: [], having: None, named_window: [], qualify: None }), order_by: [], limit: None, offset: None, fetch: None, locks: [] }))"; assert!(actual.contains(expected)); } #[test] fn to_timestamp() { let sql1 = "SELECT TO_TIMESTAMP(d) FROM t"; let statements1 = DaskParser::parse_sql(sql1).unwrap(); assert_eq!(1, statements1.len()); let actual1 = format!("{:?}", statements1[0]); let expected1 = "Statement(Query(Query { with: None, body: Select(Select { distinct: None, top: None, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: \"dsql_totimestamp\", quote_style: None }]), args: [Unnamed(Expr(Identifier(Ident { value: \"d\", quote_style: None }))), Unnamed(Expr(Value(SingleQuotedString(\"%Y-%m-%d %H:%M:%S\"))))], over: None, distinct: false, special: false, order_by: [] }))], into: None, from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: \"t\", quote_style: None }]), alias: None, args: None, with_hints: [], version: None, partitions: [] }, joins: [] }], lateral_views: [], selection: None, group_by: Expressions([]), cluster_by: [], distribute_by: [], 
sort_by: [], having: None, named_window: [], qualify: None }), order_by: [], limit: None, offset: None, fetch: None, locks: [] }))"; assert!(actual1.contains(expected1)); let sql2 = "SELECT TO_TIMESTAMP(d, \"%d/%m/%Y\") FROM t"; let statements2 = DaskParser::parse_sql(sql2).unwrap(); assert_eq!(1, statements2.len()); let actual2 = format!("{:?}", statements2[0]); let expected2 = "Statement(Query(Query { with: None, body: Select(Select { distinct: None, top: None, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: \"dsql_totimestamp\", quote_style: None }]), args: [Unnamed(Expr(Identifier(Ident { value: \"d\", quote_style: None }))), Unnamed(Expr(Value(SingleQuotedString(\"\\\"%d/%m/%Y\\\"\"))))], over: None, distinct: false, special: false, order_by: [] }))], into: None, from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: \"t\", quote_style: None }]), alias: None, args: None, with_hints: [], version: None, partitions: [] }, joins: [] }], lateral_views: [], selection: None, group_by: Expressions([]), cluster_by: [], distribute_by: [], sort_by: [], having: None, named_window: [], qualify: None }), order_by: [], limit: None, offset: None, fetch: None, locks: [] }))"; assert!(actual2.contains(expected2)); } #[test] fn create_model() { let sql = r#"CREATE MODEL my_model WITH ( model_class = 'mock.MagicMock', target_column = 'target', fit_kwargs = ( single_quoted_string = 'hello', double_quoted_string = "hi", integer = -300, float = 23.45, boolean = False, array = ARRAY [ 1, 2 ], dict = MAP [ 'a', 1 ], set = MULTISET [ 1, 1, 2, 3 ] ) ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 )"#; let statements = DaskParser::parse_sql(sql).unwrap(); assert_eq!(1, statements.len()); match &statements[0] { DaskStatement::CreateModel(create_model) => { let expected = "[\ (\"model_class\", PySqlArg { expr: Some(Value(SingleQuotedString(\"mock.MagicMock\"))), custom: None }), \ (\"target_column\", PySqlArg { expr: 
Some(Value(SingleQuotedString(\"target\"))), custom: None }), \ (\"fit_kwargs\", PySqlArg { expr: None, custom: Some(Nested([\ (\"single_quoted_string\", PySqlArg { expr: Some(Value(SingleQuotedString(\"hello\"))), custom: None }), \ (\"double_quoted_string\", PySqlArg { expr: Some(Identifier(Ident { value: \"hi\", quote_style: Some('\"') })), custom: None }), \ (\"integer\", PySqlArg { expr: Some(UnaryOp { op: Minus, expr: Value(Number(\"300\", false)) }), custom: None }), \ (\"float\", PySqlArg { expr: Some(Value(Number(\"23.45\", false))), custom: None }), \ (\"boolean\", PySqlArg { expr: Some(Value(Boolean(false))), custom: None }), \ (\"array\", PySqlArg { expr: Some(Array(Array { elem: [Value(Number(\"1\", false)), Value(Number(\"2\", false))], named: true })), custom: None }), \ (\"dict\", PySqlArg { expr: None, custom: Some(Map([Value(SingleQuotedString(\"a\")), Value(Number(\"1\", false))])) }), \ (\"set\", PySqlArg { expr: None, custom: Some(Multiset([Value(Number(\"1\", false)), Value(Number(\"1\", false)), Value(Number(\"2\", false)), Value(Number(\"3\", false))])) })\ ])) })\ ]"; assert_eq!(expected, &format!("{:?}", create_model.with_options)); } _ => panic!(), } } } ================================================ FILE: src/sql/column.rs ================================================ use datafusion_python::datafusion_common::Column; use pyo3::prelude::*; #[pyclass(name = "Column", module = "dask_sql", subclass)] #[derive(Debug, Clone)] pub struct PyColumn { /// Original Column instance pub(crate) column: Column, } impl From for Column { fn from(column: PyColumn) -> Column { column.column } } impl From for PyColumn { fn from(column: Column) -> PyColumn { PyColumn { column } } } #[pymethods] impl PyColumn { #[pyo3(name = "getRelation")] pub fn relation(&self) -> String { self.column.relation.clone().unwrap().to_string() } #[pyo3(name = "getName")] pub fn name(&self) -> String { self.column.name.clone() } } 
================================================ FILE: src/sql/exceptions.rs ================================================ use std::fmt::Debug; use pyo3::{create_exception, PyErr}; // Identifies exceptions that occur while attempting to generate a `LogicalPlan` from a SQL string create_exception!(rust, ParsingException, pyo3::exceptions::PyException); // Identifies exceptions that occur during attempts to optimize an existing `LogicalPlan` create_exception!(rust, OptimizationException, pyo3::exceptions::PyException); pub fn py_type_err(e: impl Debug) -> PyErr { PyErr::new::(format!("{e:?}")) } pub fn py_runtime_err(e: impl Debug) -> PyErr { PyErr::new::(format!("{e:?}")) } pub fn py_parsing_exp(e: impl Debug) -> PyErr { PyErr::new::(format!("{e:?}")) } pub fn py_optimization_exp(e: impl Debug) -> PyErr { PyErr::new::(format!("{e:?}")) } ================================================ FILE: src/sql/function.rs ================================================ use std::collections::HashMap; use datafusion_python::datafusion::arrow::datatypes::DataType; use pyo3::prelude::*; use super::types::PyDataType; #[pyclass(name = "DaskFunction", module = "dask_sql", subclass)] #[derive(Debug, Clone)] pub struct DaskFunction { #[pyo3(get, set)] pub(crate) name: String, pub(crate) return_types: HashMap, DataType>, pub(crate) aggregation: bool, } impl DaskFunction { pub fn new( function_name: String, input_types: Vec, return_type: PyDataType, aggregation_bool: bool, ) -> Self { let mut func = Self { name: function_name, return_types: HashMap::new(), aggregation: aggregation_bool, }; func.add_type_mapping(input_types, return_type); func } pub fn add_type_mapping(&mut self, input_types: Vec, return_type: PyDataType) { self.return_types.insert( input_types.iter().map(|t| t.clone().into()).collect(), return_type.into(), ); } } ================================================ FILE: src/sql/logical/aggregate.rs ================================================ use
datafusion_python::datafusion_expr::{ expr::{AggregateFunction, AggregateUDF, Alias}, logical_plan::{Aggregate, Distinct}, Expr, LogicalPlan, }; use pyo3::prelude::*; use crate::{ expression::{py_expr_list, PyExpr}, sql::exceptions::py_type_err, }; #[pyclass(name = "Aggregate", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PyAggregate { aggregate: Option, distinct: Option, } #[pymethods] impl PyAggregate { /// Determine the PyExprs that should be "Distinct-ed" #[pyo3(name = "getDistinctColumns")] pub fn distinct_columns(&self) -> PyResult> { match &self.distinct { Some(e) => Ok(e.input.schema().field_names()), None => Err(py_type_err( "distinct_columns invoked for non distinct instance", )), } } /// Returns a Vec of the group expressions #[pyo3(name = "getGroupSets")] pub fn group_expressions(&self) -> PyResult> { match &self.aggregate { Some(e) => py_expr_list(&e.input, &e.group_expr), None => Ok(vec![]), } } /// Returns the inner Aggregate Expr(s) #[pyo3(name = "getNamedAggCalls")] pub fn agg_expressions(&self) -> PyResult> { match &self.aggregate { Some(e) => py_expr_list(&e.input, &e.aggr_expr), None => Ok(vec![]), } } #[pyo3(name = "getAggregationFuncName")] pub fn agg_func_name(&self, expr: PyExpr) -> PyResult { _agg_func_name(&expr.expr) } #[pyo3(name = "getArgs")] pub fn aggregation_arguments(&self, expr: PyExpr) -> PyResult> { self._aggregation_arguments(&expr.expr) } #[pyo3(name = "isAggExprDistinct")] pub fn distinct_agg_expr(&self, expr: PyExpr) -> PyResult { _distinct_agg_expr(&expr.expr) } #[pyo3(name = "isDistinctNode")] pub fn distinct_node(&self) -> PyResult { Ok(self.distinct.is_some()) } } impl PyAggregate { fn _aggregation_arguments(&self, expr: &Expr) -> PyResult> { match expr { Expr::Alias(Alias { expr, .. }) => self._aggregation_arguments(expr.as_ref()), Expr::AggregateFunction(AggregateFunction { fun: _, args, .. }) | Expr::AggregateUDF(AggregateUDF { fun: _, args, .. 
}) => match &self.aggregate { Some(e) => py_expr_list(&e.input, args), None => Ok(vec![]), }, _ => Err(py_type_err( "Encountered a non Aggregate type in aggregation_arguments", )), } } } fn _agg_func_name(expr: &Expr) -> PyResult { match expr { Expr::Alias(Alias { expr, .. }) => _agg_func_name(expr.as_ref()), Expr::AggregateFunction(AggregateFunction { fun, .. }) => Ok(fun.to_string()), Expr::AggregateUDF(AggregateUDF { fun, .. }) => Ok(fun.name.clone()), _ => Err(py_type_err( "Encountered a non Aggregate type in agg_func_name", )), } } fn _distinct_agg_expr(expr: &Expr) -> PyResult { match expr { Expr::Alias(Alias { expr, .. }) => _distinct_agg_expr(expr.as_ref()), Expr::AggregateFunction(AggregateFunction { distinct, .. }) => Ok(*distinct), Expr::AggregateUDF { .. } => { // DataFusion does not support DISTINCT in UDAFs Ok(false) } _ => Err(py_type_err( "Encountered a non Aggregate type in distinct_agg_expr", )), } } impl TryFrom for PyAggregate { type Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::Aggregate(aggregate) => Ok(PyAggregate { aggregate: Some(aggregate), distinct: None, }), LogicalPlan::Distinct(distinct) => Ok(PyAggregate { aggregate: None, distinct: Some(distinct), }), _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/alter_schema.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{ logical_plan::{Extension, UserDefinedLogicalNode}, Expr, LogicalPlan, }, }; use fmt::Debug; use pyo3::prelude::*; use crate::sql::{exceptions::py_type_err, logical}; #[derive(Clone, PartialEq)] pub struct AlterSchemaPlanNode { pub schema: DFSchemaRef, pub old_schema_name: String, pub new_schema_name: String, } impl Debug for AlterSchemaPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 
self.fmt_for_explain(f) } } impl Hash for AlterSchemaPlanNode { fn hash(&self, state: &mut H) { self.schema.hash(state); self.old_schema_name.hash(state); self.new_schema_name.hash(state); } } impl UserDefinedLogicalNode for AlterSchemaPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // ALTER SCHEMA {table_name} vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "Alter Schema: old_schema_name: {:?}, new_schema_name: {:?}", self.old_schema_name, self.new_schema_name ) } fn from_template( &self, _exprs: &[Expr], _inputs: &[LogicalPlan], ) -> Arc { Arc::new(AlterSchemaPlanNode { schema: Arc::new(DFSchema::empty()), old_schema_name: self.old_schema_name.clone(), new_schema_name: self.new_schema_name.clone(), }) } fn name(&self) -> &str { "AlterSchema" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "AlterSchema", module = "dask_sql", subclass)] pub struct PyAlterSchema { pub(crate) alter_schema: AlterSchemaPlanNode, } #[pymethods] impl PyAlterSchema { #[pyo3(name = "getOldSchemaName")] fn get_old_schema_name(&self) -> PyResult { Ok(self.alter_schema.old_schema_name.clone()) } #[pyo3(name = "getNewSchemaName")] fn get_new_schema_name(&self) -> PyResult { Ok(self.alter_schema.new_schema_name.clone()) } } impl TryFrom for PyAlterSchema { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { LogicalPlan::Extension(Extension { node }) if node .as_any() .downcast_ref::() .is_some() => { let ext = node .as_any() .downcast_ref::() 
.expect("AlterSchemaPlanNode"); Ok(PyAlterSchema { alter_schema: ext.clone(), }) } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/alter_table.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{ logical_plan::{Extension, UserDefinedLogicalNode}, Expr, LogicalPlan, }, }; use fmt::Debug; use pyo3::prelude::*; use crate::sql::{exceptions::py_type_err, logical}; #[derive(Clone, PartialEq)] pub struct AlterTablePlanNode { pub schema: DFSchemaRef, pub old_table_name: String, pub new_table_name: String, pub schema_name: Option, pub if_exists: bool, } impl Debug for AlterTablePlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for AlterTablePlanNode { fn hash(&self, state: &mut H) { self.schema.hash(state); self.old_table_name.hash(state); self.new_table_name.hash(state); self.schema_name.hash(state); self.if_exists.hash(state); } } impl UserDefinedLogicalNode for AlterTablePlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // ALTER TABLE {table_name} vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "Alter Table: old_table_name: {:?}, new_table_name: {:?}, schema_name: {:?}", self.old_table_name, self.new_table_name, self.schema_name ) } fn from_template( &self, _exprs: &[Expr], _inputs: &[LogicalPlan], ) -> Arc { Arc::new(AlterTablePlanNode { schema: Arc::new(DFSchema::empty()), old_table_name: self.old_table_name.clone(), new_table_name: self.new_table_name.clone(), schema_name: 
self.schema_name.clone(), if_exists: self.if_exists, }) } fn name(&self) -> &str { "AlterTable" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "AlterTable", module = "dask_sql", subclass)] pub struct PyAlterTable { pub(crate) alter_table: AlterTablePlanNode, } #[pymethods] impl PyAlterTable { #[pyo3(name = "getOldTableName")] fn get_old_table_name(&self) -> PyResult { Ok(self.alter_table.old_table_name.clone()) } #[pyo3(name = "getNewTableName")] fn get_new_table_name(&self) -> PyResult { Ok(self.alter_table.new_table_name.clone()) } #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult> { Ok(self.alter_table.schema_name.clone()) } #[pyo3(name = "getIfExists")] fn get_if_exists(&self) -> PyResult { Ok(self.alter_table.if_exists) } } impl TryFrom for PyAlterTable { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { LogicalPlan::Extension(Extension { node }) if node.as_any().downcast_ref::().is_some() => { let ext = node .as_any() .downcast_ref::() .expect("AlterTablePlanNode"); Ok(PyAlterTable { alter_table: ext.clone(), }) } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/analyze_table.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{ logical_plan::{Extension, UserDefinedLogicalNode}, Expr, LogicalPlan, }, }; use fmt::Debug; use pyo3::prelude::*; use crate::sql::{exceptions::py_type_err, logical}; #[derive(Clone, PartialEq)] pub struct AnalyzeTablePlanNode { pub schema: DFSchemaRef, pub table_name: String, pub schema_name: Option, pub columns: Vec, } impl Debug for 
AnalyzeTablePlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for AnalyzeTablePlanNode { fn hash(&self, state: &mut H) { self.schema.hash(state); self.table_name.hash(state); self.schema_name.hash(state); self.columns.hash(state); } } impl UserDefinedLogicalNode for AnalyzeTablePlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // ANALYZE TABLE {table_name} vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "Analyze Table: table_name: {:?}, columns: {:?}", self.table_name, self.columns ) } fn from_template( &self, _exprs: &[Expr], _inputs: &[LogicalPlan], ) -> Arc { Arc::new(AnalyzeTablePlanNode { schema: Arc::new(DFSchema::empty()), table_name: self.table_name.clone(), schema_name: self.schema_name.clone(), columns: self.columns.clone(), }) } fn name(&self) -> &str { "AnalyzeTable" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "AnalyzeTable", module = "dask_sql", subclass)] pub struct PyAnalyzeTable { pub(crate) analyze_table: AnalyzeTablePlanNode, } #[pymethods] impl PyAnalyzeTable { #[pyo3(name = "getTableName")] fn get_table_name(&self) -> PyResult { Ok(self.analyze_table.table_name.clone()) } #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult> { Ok(self.analyze_table.schema_name.clone()) } #[pyo3(name = "getColumns")] fn get_columns(&self) -> PyResult> { Ok(self.analyze_table.columns.clone()) } } impl TryFrom for PyAnalyzeTable { type Error = PyErr; fn try_from(logical_plan: 
logical::LogicalPlan) -> Result { match logical_plan { LogicalPlan::Extension(Extension { node }) if node .as_any() .downcast_ref::() .is_some() => { let ext = node .as_any() .downcast_ref::() .expect("AnalyzeTablePlanNode"); Ok(PyAnalyzeTable { analyze_table: ext.clone(), }) } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/create_catalog_schema.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan}, }; use fmt::Debug; use pyo3::prelude::*; use crate::sql::{exceptions::py_type_err, logical}; #[derive(Clone, PartialEq)] pub struct CreateCatalogSchemaPlanNode { pub schema: DFSchemaRef, pub schema_name: String, pub if_not_exists: bool, pub or_replace: bool, } impl Debug for CreateCatalogSchemaPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for CreateCatalogSchemaPlanNode { fn hash(&self, state: &mut H) { self.schema.hash(state); self.schema_name.hash(state); self.if_not_exists.hash(state); self.or_replace.hash(state); } } impl UserDefinedLogicalNode for CreateCatalogSchemaPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // CREATE SCHEMA vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "CreateCatalogSchema: schema_name={}, or_replace={}, if_not_exists={}", self.schema_name, self.or_replace, self.if_not_exists ) } fn from_template( &self, _exprs: &[Expr], _inputs: &[LogicalPlan], ) -> Arc { Arc::new(CreateCatalogSchemaPlanNode { schema: 
Arc::new(DFSchema::empty()), schema_name: self.schema_name.clone(), if_not_exists: self.if_not_exists, or_replace: self.or_replace, }) } fn name(&self) -> &str { "CreateCatalogSchema" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "CreateCatalogSchema", module = "dask_sql", subclass)] pub struct PyCreateCatalogSchema { pub(crate) create_catalog_schema: CreateCatalogSchemaPlanNode, } #[pymethods] impl PyCreateCatalogSchema { #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult { Ok(self.create_catalog_schema.schema_name.clone()) } #[pyo3(name = "getIfNotExists")] fn get_if_not_exists(&self) -> PyResult { Ok(self.create_catalog_schema.if_not_exists) } #[pyo3(name = "getReplace")] fn get_replace(&self) -> PyResult { Ok(self.create_catalog_schema.or_replace) } } impl TryFrom for PyCreateCatalogSchema { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { logical::LogicalPlan::Extension(extension) => { if let Some(ext) = extension .node .as_any() .downcast_ref::() { Ok(PyCreateCatalogSchema { create_catalog_schema: ext.clone(), }) } else { Err(py_type_err("unexpected plan")) } } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/create_experiment.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::DFSchemaRef, datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan}, }; use fmt::Debug; use pyo3::prelude::*; use crate::{ parser::PySqlArg, sql::{exceptions::py_type_err, logical}, }; #[derive(Clone, PartialEq)] pub struct CreateExperimentPlanNode { pub schema_name: Option, pub experiment_name: String, pub input: 
LogicalPlan, pub if_not_exists: bool, pub or_replace: bool, pub with_options: Vec<(String, PySqlArg)>, } impl Debug for CreateExperimentPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for CreateExperimentPlanNode { fn hash(&self, state: &mut H) { self.schema_name.hash(state); self.experiment_name.hash(state); self.input.hash(state); self.if_not_exists.hash(state); self.or_replace.hash(state); // self.with_options.hash(state); } } impl UserDefinedLogicalNode for CreateExperimentPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![&self.input] } fn schema(&self) -> &DFSchemaRef { self.input.schema() } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // CREATE EXPERIMENT vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "CreateExperiment: experiment_name={}", self.experiment_name ) } fn from_template( &self, _exprs: &[Expr], inputs: &[LogicalPlan], ) -> Arc { assert_eq!(inputs.len(), 1, "input size inconsistent"); Arc::new(CreateExperimentPlanNode { schema_name: self.schema_name.clone(), experiment_name: self.experiment_name.clone(), input: inputs[0].clone(), if_not_exists: self.if_not_exists, or_replace: self.or_replace, with_options: self.with_options.clone(), }) } fn name(&self) -> &str { "CreateExperiment" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "CreateExperiment", module = "dask_sql", subclass)] pub struct PyCreateExperiment { pub(crate) create_experiment: CreateExperimentPlanNode, } #[pymethods] impl PyCreateExperiment { /// Creating an experiment requires that a subquery be passed to the CREATE 
EXPERIMENT /// statement to be used to gather the dataset which should be used for the /// experiment. This function returns that portion of the statement. #[pyo3(name = "getSelectQuery")] fn get_select_query(&self) -> PyResult { Ok(self.create_experiment.input.clone().into()) } #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult> { Ok(self.create_experiment.schema_name.clone()) } #[pyo3(name = "getExperimentName")] fn get_experiment_name(&self) -> PyResult { Ok(self.create_experiment.experiment_name.clone()) } #[pyo3(name = "getIfNotExists")] fn get_if_not_exists(&self) -> PyResult { Ok(self.create_experiment.if_not_exists) } #[pyo3(name = "getOrReplace")] pub fn get_or_replace(&self) -> PyResult { Ok(self.create_experiment.or_replace) } #[pyo3(name = "getSQLWithOptions")] fn sql_with_options(&self) -> PyResult> { Ok(self.create_experiment.with_options.clone()) } } impl TryFrom for PyCreateExperiment { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { logical::LogicalPlan::Extension(extension) => { if let Some(ext) = extension .node .as_any() .downcast_ref::() { Ok(PyCreateExperiment { create_experiment: ext.clone(), }) } else { Err(py_type_err("unexpected plan")) } } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/create_memory_table.rs ================================================ use datafusion_python::datafusion_expr::{ logical_plan::{CreateMemoryTable, CreateView}, DdlStatement, LogicalPlan, }; use pyo3::prelude::*; use crate::sql::{exceptions::py_type_err, logical::PyLogicalPlan}; #[pyclass(name = "CreateMemoryTable", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PyCreateMemoryTable { create_memory_table: Option, create_view: Option, } #[pymethods] impl PyCreateMemoryTable { #[pyo3(name = "getQualifiedName")] pub fn get_table_name(&self) -> PyResult { Ok(match &self.create_memory_table { 
Some(create_memory_table) => create_memory_table.name.to_string(), None => match &self.create_view { Some(create_view) => create_view.name.to_string(), None => { return Err(py_type_err( "Encountered a non CreateMemoryTable/CreateView type in get_input", )) } }, }) } #[pyo3(name = "getInput")] pub fn get_input(&self) -> PyResult { Ok(match &self.create_memory_table { Some(create_memory_table) => PyLogicalPlan { original_plan: (*create_memory_table.input).clone(), current_node: None, }, None => match &self.create_view { Some(create_view) => PyLogicalPlan { original_plan: (*create_view.input).clone(), current_node: None, }, None => { return Err(py_type_err( "Encountered a non CreateMemoryTable/CreateView type in get_input", )) } }, }) } #[pyo3(name = "getIfNotExists")] pub fn get_if_not_exists(&self) -> PyResult { Ok(match &self.create_memory_table { Some(create_memory_table) => create_memory_table.if_not_exists, None => false, // TODO: in the future we may want to set this based on dialect }) } #[pyo3(name = "getOrReplace")] pub fn get_or_replace(&self) -> PyResult { Ok(match &self.create_memory_table { Some(create_memory_table) => create_memory_table.or_replace, None => match &self.create_view { Some(create_view) => create_view.or_replace, None => { return Err(py_type_err( "Encountered a non CreateMemoryTable/CreateView type in get_input", )) } }, }) } #[pyo3(name = "isTable")] pub fn is_table(&self) -> PyResult { Ok(self.create_memory_table.is_some()) } } impl TryFrom for PyCreateMemoryTable { type Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { Ok(match logical_plan { LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(cmt)) => PyCreateMemoryTable { create_memory_table: Some(cmt), create_view: None, }, LogicalPlan::Ddl(DdlStatement::CreateView(cv)) => PyCreateMemoryTable { create_memory_table: None, create_view: Some(cv), }, _ => return Err(py_type_err("unexpected plan")), }) } } ================================================ FILE: 
src/sql/logical/create_model.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::DFSchemaRef, datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan}, }; use fmt::Debug; use pyo3::prelude::*; use crate::{ parser::PySqlArg, sql::{exceptions::py_type_err, logical}, }; #[derive(Clone, PartialEq)] pub struct CreateModelPlanNode { pub schema_name: Option, pub model_name: String, pub input: LogicalPlan, pub if_not_exists: bool, pub or_replace: bool, pub with_options: Vec<(String, PySqlArg)>, } impl Debug for CreateModelPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for CreateModelPlanNode { fn hash(&self, state: &mut H) { self.schema_name.hash(state); self.model_name.hash(state); self.input.hash(state); self.if_not_exists.hash(state); self.or_replace.hash(state); // self.with_options.hash(state); } } impl UserDefinedLogicalNode for CreateModelPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![&self.input] } fn schema(&self) -> &DFSchemaRef { self.input.schema() } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // CREATE MODEL vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "CreateModel: model_name={}", self.model_name) } fn from_template( &self, _exprs: &[Expr], inputs: &[LogicalPlan], ) -> Arc { assert_eq!(inputs.len(), 1, "input size inconsistent"); Arc::new(CreateModelPlanNode { schema_name: self.schema_name.clone(), model_name: self.model_name.clone(), input: inputs[0].clone(), if_not_exists: self.if_not_exists, or_replace: self.or_replace, with_options: self.with_options.clone(), }) } fn name(&self) -> &str { "CreateModel" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = 
state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "CreateModel", module = "dask_sql", subclass)] pub struct PyCreateModel { pub(crate) create_model: CreateModelPlanNode, } #[pymethods] impl PyCreateModel { /// Creating a model requires that a subquery be passed to the CREATE MODEL /// statement to be used to gather the dataset which should be used for the /// model. This function returns that portion of the statement. #[pyo3(name = "getSelectQuery")] fn get_select_query(&self) -> PyResult { Ok(self.create_model.input.clone().into()) } #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult> { Ok(self.create_model.schema_name.clone()) } #[pyo3(name = "getModelName")] fn get_model_name(&self) -> PyResult { Ok(self.create_model.model_name.clone()) } #[pyo3(name = "getIfNotExists")] fn get_if_not_exists(&self) -> PyResult { Ok(self.create_model.if_not_exists) } #[pyo3(name = "getOrReplace")] pub fn get_or_replace(&self) -> PyResult { Ok(self.create_model.or_replace) } #[pyo3(name = "getSQLWithOptions")] fn sql_with_options(&self) -> PyResult> { Ok(self.create_model.with_options.clone()) } } impl TryFrom for PyCreateModel { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { logical::LogicalPlan::Extension(extension) => { if let Some(ext) = extension .node .as_any() .downcast_ref::() { Ok(PyCreateModel { create_model: ext.clone(), }) } else { Err(py_type_err("unexpected plan")) } } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/create_table.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan}, }; 
use fmt::Debug; use pyo3::prelude::*; use crate::{ parser::PySqlArg, sql::{exceptions::py_type_err, logical}, }; #[derive(Clone, PartialEq)] pub struct CreateTablePlanNode { pub schema: DFSchemaRef, pub schema_name: Option, // "something" in `something.table_name` pub table_name: String, pub if_not_exists: bool, pub or_replace: bool, pub with_options: Vec<(String, PySqlArg)>, } impl Debug for CreateTablePlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for CreateTablePlanNode { fn hash(&self, state: &mut H) { self.schema.hash(state); self.schema_name.hash(state); self.table_name.hash(state); self.if_not_exists.hash(state); self.or_replace.hash(state); // self.with_options.hash(state); } } impl UserDefinedLogicalNode for CreateTablePlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // CREATE TABLE vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "CreateTable: table_name={}", self.table_name) } fn from_template( &self, _exprs: &[Expr], _inputs: &[LogicalPlan], ) -> Arc { Arc::new(CreateTablePlanNode { schema: Arc::new(DFSchema::empty()), schema_name: self.schema_name.clone(), table_name: self.table_name.clone(), if_not_exists: self.if_not_exists, or_replace: self.or_replace, with_options: self.with_options.clone(), }) } fn name(&self) -> &str { "CreateTable" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "CreateTable", module = "dask_sql", subclass)] pub struct PyCreateTable { pub(crate) create_table: 
CreateTablePlanNode, } #[pymethods] impl PyCreateTable { #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult> { Ok(self.create_table.schema_name.clone()) } #[pyo3(name = "getTableName")] fn get_table_name(&self) -> PyResult { Ok(self.create_table.table_name.clone()) } #[pyo3(name = "getIfNotExists")] fn get_if_not_exists(&self) -> PyResult { Ok(self.create_table.if_not_exists) } #[pyo3(name = "getOrReplace")] fn get_or_replace(&self) -> PyResult { Ok(self.create_table.or_replace) } #[pyo3(name = "getSQLWithOptions")] fn sql_with_options(&self) -> PyResult> { Ok(self.create_table.with_options.clone()) } } impl TryFrom for PyCreateTable { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { logical::LogicalPlan::Extension(extension) => { if let Some(ext) = extension .node .as_any() .downcast_ref::() { Ok(PyCreateTable { create_table: ext.clone(), }) } else { Err(py_type_err("unexpected plan")) } } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/describe_model.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan}, }; use fmt::Debug; use pyo3::prelude::*; use crate::sql::{exceptions::py_type_err, logical}; #[derive(Clone, PartialEq)] pub struct DescribeModelPlanNode { pub schema: DFSchemaRef, pub schema_name: Option, pub model_name: String, } impl Debug for DescribeModelPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for DescribeModelPlanNode { fn hash(&self, state: &mut H) { self.schema.hash(state); self.schema_name.hash(state); self.model_name.hash(state); } } impl UserDefinedLogicalNode for DescribeModelPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> 
Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // DESCRIBE MODEL vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "DescribeModel: model_name={}", self.model_name) } fn from_template( &self, _exprs: &[Expr], inputs: &[LogicalPlan], ) -> Arc { assert_eq!(inputs.len(), 0, "input size inconsistent"); Arc::new(DescribeModelPlanNode { schema: Arc::new(DFSchema::empty()), schema_name: self.schema_name.clone(), model_name: self.model_name.clone(), }) } fn name(&self) -> &str { "DescribeModel" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "DescribeModel", module = "dask_sql", subclass)] pub struct PyDescribeModel { pub(crate) describe_model: DescribeModelPlanNode, } #[pymethods] impl PyDescribeModel { #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult> { Ok(self.describe_model.schema_name.clone()) } #[pyo3(name = "getModelName")] fn get_model_name(&self) -> PyResult { Ok(self.describe_model.model_name.clone()) } } impl TryFrom for PyDescribeModel { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { logical::LogicalPlan::Extension(extension) => { if let Some(ext) = extension .node .as_any() .downcast_ref::() { Ok(PyDescribeModel { describe_model: ext.clone(), }) } else { Err(py_type_err("unexpected plan")) } } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/drop_model.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use 
datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan}, }; use fmt::Debug; use pyo3::prelude::*; use crate::sql::{exceptions::py_type_err, logical}; #[derive(Clone, PartialEq)] pub struct DropModelPlanNode { pub schema_name: Option, pub model_name: String, pub if_exists: bool, pub schema: DFSchemaRef, } impl Debug for DropModelPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for DropModelPlanNode { fn hash(&self, state: &mut H) { self.schema_name.hash(state); self.model_name.hash(state); self.if_exists.hash(state); self.schema.hash(state); } } impl UserDefinedLogicalNode for DropModelPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // DROP MODEL vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "DropModel: model_name={}", self.model_name) } fn from_template( &self, _exprs: &[Expr], inputs: &[LogicalPlan], ) -> Arc { assert_eq!(inputs.len(), 0, "input size inconsistent"); Arc::new(DropModelPlanNode { schema_name: self.schema_name.clone(), model_name: self.model_name.clone(), if_exists: self.if_exists, schema: Arc::new(DFSchema::empty()), }) } fn name(&self) -> &str { "DropModel" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "DropModel", module = "dask_sql", subclass)] pub struct PyDropModel { pub(crate) drop_model: DropModelPlanNode, } #[pymethods] impl PyDropModel { #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> 
PyResult> { Ok(self.drop_model.schema_name.clone()) } #[pyo3(name = "getModelName")] fn get_model_name(&self) -> PyResult { Ok(self.drop_model.model_name.clone()) } #[pyo3(name = "getIfExists")] pub fn get_if_exists(&self) -> PyResult { Ok(self.drop_model.if_exists) } } impl TryFrom for PyDropModel { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { logical::LogicalPlan::Extension(extension) => { if let Some(ext) = extension.node.as_any().downcast_ref::() { Ok(PyDropModel { drop_model: ext.clone(), }) } else { Err(py_type_err("unexpected plan")) } } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/drop_schema.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan}, }; use fmt::Debug; use pyo3::prelude::*; use crate::sql::{exceptions::py_type_err, logical}; #[derive(Clone, PartialEq)] pub struct DropSchemaPlanNode { pub schema: DFSchemaRef, pub schema_name: String, pub if_exists: bool, } impl Debug for DropSchemaPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for DropSchemaPlanNode { fn hash(&self, state: &mut H) { self.schema.hash(state); self.schema_name.hash(state); self.if_exists.hash(state); } } impl UserDefinedLogicalNode for DropSchemaPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // DROP SCHEMA vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "DropSchema: schema_name={}", 
self.schema_name) } fn from_template( &self, _exprs: &[Expr], _inputs: &[LogicalPlan], ) -> Arc { Arc::new(DropSchemaPlanNode { schema: Arc::new(DFSchema::empty()), schema_name: self.schema_name.clone(), if_exists: self.if_exists, }) } fn name(&self) -> &str { "DropSchema" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "DropSchema", module = "dask_sql", subclass)] pub struct PyDropSchema { pub(crate) drop_schema: DropSchemaPlanNode, } #[pymethods] impl PyDropSchema { #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult { Ok(self.drop_schema.schema_name.clone()) } #[pyo3(name = "getIfExists")] fn get_if_exists(&self) -> PyResult { Ok(self.drop_schema.if_exists) } } impl TryFrom for PyDropSchema { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { logical::LogicalPlan::Extension(extension) => { if let Some(ext) = extension.node.as_any().downcast_ref::() { Ok(PyDropSchema { drop_schema: ext.clone(), }) } else { Err(py_type_err("unexpected plan")) } } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/drop_table.rs ================================================ use datafusion_python::datafusion_expr::{ logical_plan::{DropTable, LogicalPlan}, DdlStatement, }; use pyo3::prelude::*; use crate::sql::exceptions::py_type_err; #[pyclass(name = "DropTable", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PyDropTable { drop_table: DropTable, } #[pymethods] impl PyDropTable { #[pyo3(name = "getQualifiedName")] pub fn get_name(&self) -> PyResult { Ok(self.drop_table.name.to_string()) } #[pyo3(name = "getIfExists")] pub fn get_if_exists(&self) -> PyResult { Ok(self.drop_table.if_exists) } } impl TryFrom for PyDropTable { type Error = 
PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::Ddl(DdlStatement::DropTable(drop_table)) => Ok(PyDropTable { drop_table }), _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/empty_relation.rs ================================================ use datafusion_python::datafusion_expr::logical_plan::{EmptyRelation, LogicalPlan}; use pyo3::prelude::*; use crate::sql::exceptions::py_type_err; #[pyclass(name = "EmptyRelation", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PyEmptyRelation { empty_relation: EmptyRelation, } impl TryFrom for PyEmptyRelation { type Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::EmptyRelation(empty_relation) => Ok(PyEmptyRelation { empty_relation }), _ => Err(py_type_err("unexpected plan")), } } } #[pymethods] impl PyEmptyRelation { /// Even though a relation results in an "empty" table column names /// will still be projected and must be captured in order to present /// the expected output to the user. 
This logic captures the names /// of those columns and returns them to the Python logic where /// they are rendered to the user #[pyo3(name = "emptyColumnNames")] pub fn empty_column_names(&self) -> PyResult> { Ok(self.empty_relation.schema.field_names()) } } ================================================ FILE: src/sql/logical/explain.rs ================================================ use datafusion_python::datafusion_expr::{logical_plan::Explain, LogicalPlan}; use pyo3::prelude::*; use crate::sql::exceptions::py_type_err; #[pyclass(name = "Explain", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PyExplain { explain: Explain, } #[pymethods] impl PyExplain { /// Returns explain strings #[pyo3(name = "getExplainString")] pub fn get_explain_string(&self) -> PyResult> { let mut string_plans: Vec = Vec::new(); for stringified_plan in &self.explain.stringified_plans { string_plans.push((*stringified_plan.plan).clone()); } Ok(string_plans) } } impl TryFrom for PyExplain { type Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::Explain(explain) => Ok(PyExplain { explain }), _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/export_model.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan}, }; use fmt::Debug; use pyo3::prelude::*; use crate::{ parser::PySqlArg, sql::{exceptions::py_type_err, logical}, }; #[derive(Clone, PartialEq)] pub struct ExportModelPlanNode { pub schema: DFSchemaRef, pub schema_name: Option, pub model_name: String, pub with_options: Vec<(String, PySqlArg)>, } impl Debug for ExportModelPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for ExportModelPlanNode { fn 
hash(&self, state: &mut H) { self.schema.hash(state); self.schema_name.hash(state); self.model_name.hash(state); // self.with_options.hash(state); } } impl UserDefinedLogicalNode for ExportModelPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // EXPORT MODEL vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "ExportModel: model_name={}", self.model_name) } fn from_template( &self, _exprs: &[Expr], inputs: &[LogicalPlan], ) -> Arc { assert_eq!(inputs.len(), 0, "input size inconsistent"); Arc::new(ExportModelPlanNode { schema: Arc::new(DFSchema::empty()), schema_name: self.schema_name.clone(), model_name: self.model_name.clone(), with_options: self.with_options.clone(), }) } fn name(&self) -> &str { "ExportModel" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "ExportModel", module = "dask_sql", subclass)] pub struct PyExportModel { pub(crate) export_model: ExportModelPlanNode, } #[pymethods] impl PyExportModel { #[pyo3(name = "getModelName")] fn get_model_name(&self) -> PyResult { Ok(self.export_model.model_name.clone()) } #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult> { Ok(self.export_model.schema_name.clone()) } #[pyo3(name = "getSQLWithOptions")] fn sql_with_options(&self) -> PyResult> { Ok(self.export_model.with_options.clone()) } } impl TryFrom for PyExportModel { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { logical::LogicalPlan::Extension(extension) => { if let Some(ext) = extension 
.node .as_any() .downcast_ref::() { Ok(PyExportModel { export_model: ext.clone(), }) } else { Err(py_type_err("unexpected plan")) } } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/filter.rs ================================================ use datafusion_python::datafusion_expr::{logical_plan::Filter, LogicalPlan}; use pyo3::prelude::*; use crate::{expression::PyExpr, sql::exceptions::py_type_err}; #[pyclass(name = "Filter", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PyFilter { filter: Filter, } #[pymethods] impl PyFilter { /// LogicalPlan::Filter: The PyExpr, predicate, that represents the filtering condition #[pyo3(name = "getCondition")] pub fn get_condition(&mut self) -> PyResult { Ok(PyExpr::from( self.filter.predicate.clone(), Some(vec![self.filter.input.clone()]), )) } } impl TryFrom for PyFilter { type Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::Filter(filter) => Ok(PyFilter { filter }), _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/join.rs ================================================ use datafusion_python::{ datafusion_common::Column, datafusion_expr::{ and, logical_plan::{Join, JoinType, LogicalPlan}, BinaryExpr, Expr, Operator, }, }; use pyo3::prelude::*; use crate::{ expression::PyExpr, sql::{column, exceptions::py_type_err}, }; #[pyclass(name = "Join", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PyJoin { join: Join, } #[pymethods] impl PyJoin { #[pyo3(name = "getCondition")] pub fn join_condition(&self) -> PyResult> { // equi-join filters let mut filters: Vec = self .join .on .iter() .map(|(l, r)| match (l, r) { (Expr::Column(l), Expr::Column(r)) => { Ok(Expr::Column(l.clone()).eq(Expr::Column(r.clone()))) } (Expr::Column(l), Expr::Cast(cast)) => { let right = Column::from_qualified_name(cast.expr.to_string()); 
Ok(Expr::Column(l.clone()).eq(Expr::Column(right))) } (Expr::Column(l), Expr::BinaryExpr(bin_expr)) => { Ok(Expr::BinaryExpr(BinaryExpr::new( Box::new(Expr::Column(l.clone())), Operator::Eq, Box::new(Expr::BinaryExpr(bin_expr.clone())), ))) } _ => Err(py_type_err(format!( "unsupported join condition. Left: {l} - Right: {r}" ))), }) .collect::, _>>()?; // other filter conditions if let Some(filter) = &self.join.filter { filters.push(filter.clone()); } if !filters.is_empty() { let root_expr = filters[1..] .iter() .fold(filters[0].clone(), |acc, expr| and(acc, expr.clone())); Ok(Some(PyExpr::from( root_expr, Some(vec![self.join.left.clone(), self.join.right.clone()]), ))) } else { Ok(None) } } #[pyo3(name = "getJoinConditions")] pub fn join_conditions(&mut self) -> PyResult> { // let lhs_table_name = match &*self.join.left { // LogicalPlan::TableScan(scan) => scan.table_name.clone(), // _ => { // return Err(py_type_err( // "lhs Expected TableScan but something else was received!", // )) // } // }; // let rhs_table_name = match &*self.join.right { // LogicalPlan::TableScan(scan) => scan.table_name.clone(), // _ => { // return Err(py_type_err( // "rhs Expected TableScan but something else was received!", // )) // } // }; let mut join_conditions: Vec<(column::PyColumn, column::PyColumn)> = Vec::new(); for (lhs, rhs) in self.join.on.clone() { match (lhs, rhs) { (Expr::Column(lhs), Expr::Column(rhs)) => { join_conditions.push((lhs.into(), rhs.into())); } _ => return Err(py_type_err("unsupported join condition")), } } Ok(join_conditions) } /// Returns the type of join represented by this LogicalPlan::Join instance #[pyo3(name = "getJoinType")] pub fn join_type(&mut self) -> PyResult { match self.join.join_type { JoinType::Inner => Ok("INNER".to_string()), JoinType::Left => Ok("LEFT".to_string()), JoinType::Right => Ok("RIGHT".to_string()), JoinType::Full => Ok("FULL".to_string()), JoinType::LeftSemi => Ok("LEFTSEMI".to_string()), JoinType::LeftAnti => 
Ok("LEFTANTI".to_string()), JoinType::RightSemi => Ok("RIGHTSEMI".to_string()), JoinType::RightAnti => Ok("RIGHTANTI".to_string()), } } } impl TryFrom for PyJoin { type Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::Join(join) => Ok(PyJoin { join }), _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/limit.rs ================================================ use datafusion_python::{ datafusion_common::ScalarValue, datafusion_expr::{logical_plan::Limit, Expr, LogicalPlan}, }; use pyo3::prelude::*; use crate::{expression::PyExpr, sql::exceptions::py_type_err}; #[pyclass(name = "Limit", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PyLimit { limit: Limit, } #[pymethods] impl PyLimit { /// `OFFSET` specified in the query #[pyo3(name = "getSkip")] pub fn skip(&self) -> PyResult { Ok(PyExpr::from( Expr::Literal(ScalarValue::UInt64(Some(self.limit.skip as u64))), Some(vec![self.limit.input.clone()]), )) } /// `LIMIT` specified in the query #[pyo3(name = "getFetch")] pub fn fetch(&self) -> PyResult { Ok(PyExpr::from( Expr::Literal(ScalarValue::UInt64(Some( self.limit.fetch.unwrap_or(0) as u64 ))), Some(vec![self.limit.input.clone()]), )) } } impl TryFrom for PyLimit { type Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::Limit(limit) => Ok(PyLimit { limit }), _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/predict_model.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::DFSchemaRef, datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan}, }; use fmt::Debug; use pyo3::prelude::*; use super::PyLogicalPlan; use crate::sql::{exceptions::py_type_err, logical}; #[derive(Clone, PartialEq)] pub struct 
PredictModelPlanNode { pub schema_name: Option, // "something" in `something.model_name` pub model_name: String, pub input: LogicalPlan, } impl Debug for PredictModelPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for PredictModelPlanNode { fn hash(&self, state: &mut H) { self.schema_name.hash(state); self.model_name.hash(state); self.input.hash(state); } } impl UserDefinedLogicalNode for PredictModelPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![&self.input] } fn schema(&self) -> &DFSchemaRef { self.input.schema() } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // PREDICT TABLE vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "PredictModel: model_name={}", self.model_name) } fn from_template( &self, _exprs: &[Expr], inputs: &[LogicalPlan], ) -> Arc { Arc::new(PredictModelPlanNode { schema_name: self.schema_name.clone(), model_name: self.model_name.clone(), input: inputs[0].clone(), }) } fn name(&self) -> &str { "PredictModel" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "PredictModel", module = "dask_sql", subclass)] pub struct PyPredictModel { pub(crate) predict_model: PredictModelPlanNode, } #[pymethods] impl PyPredictModel { #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult> { Ok(self.predict_model.schema_name.clone()) } #[pyo3(name = "getModelName")] fn get_model_name(&self) -> PyResult { Ok(self.predict_model.model_name.clone()) } #[pyo3(name = "getSelect")] fn get_select(&self) -> PyResult { Ok(PyLogicalPlan::from(self.predict_model.input.clone())) } } impl TryFrom for 
PyPredictModel { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { logical::LogicalPlan::Extension(extension) => { if let Some(ext) = extension .node .as_any() .downcast_ref::() { Ok(PyPredictModel { predict_model: ext.clone(), }) } else { Err(py_type_err("unexpected plan")) } } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/projection.rs ================================================ use datafusion_python::datafusion_expr::{ expr::Alias, logical_plan::Projection, Expr, LogicalPlan, }; use pyo3::prelude::*; use crate::{expression::PyExpr, sql::exceptions::py_type_err}; #[pyclass(name = "Projection", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PyProjection { pub(crate) projection: Projection, } impl PyProjection { /// Projection: Gets the names of the fields that should be projected fn projected_expressions(&mut self, local_expr: &PyExpr) -> Vec { let mut projs: Vec = Vec::new(); match &local_expr.expr { Expr::Alias(Alias { expr, .. 
}) => { let py_expr: PyExpr = PyExpr::from(*expr.clone(), Some(vec![self.projection.input.clone()])); projs.extend_from_slice(self.projected_expressions(&py_expr).as_slice()); } _ => projs.push(local_expr.clone()), } projs } } #[pymethods] impl PyProjection { #[pyo3(name = "getNamedProjects")] fn named_projects(&mut self) -> PyResult> { let mut named: Vec<(String, PyExpr)> = Vec::new(); for expression in self.projection.expr.clone() { let py_expr: PyExpr = PyExpr::from(expression, Some(vec![self.projection.input.clone()])); for expr in self.projected_expressions(&py_expr) { match expr.expr { Expr::Alias(Alias { expr, name }) => named.push(( name.to_string(), PyExpr::from(*expr, Some(vec![self.projection.input.clone()])), )), _ => { if let Ok(name) = expr._column_name(&self.projection.input) { named.push((name, expr.clone())); } } } } } Ok(named) } } impl TryFrom for PyProjection { type Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::Projection(projection) => Ok(PyProjection { projection }), _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/repartition_by.rs ================================================ use datafusion_python::datafusion_expr::{ logical_plan::{Partitioning, Repartition}, Expr, LogicalPlan, }; use pyo3::prelude::*; use crate::{ expression::PyExpr, sql::{exceptions::py_type_err, logical}, }; #[pyclass(name = "RepartitionBy", module = "dask_sql", subclass)] pub struct PyRepartitionBy { pub(crate) repartition: Repartition, } #[pymethods] impl PyRepartitionBy { #[pyo3(name = "getSelectQuery")] fn get_select_query(&self) -> PyResult { let log_plan = &*(self.repartition.input).clone(); Ok(log_plan.clone().into()) } #[pyo3(name = "getDistributeList")] fn get_distribute_list(&self) -> PyResult> { match &self.repartition.partitioning_scheme { Partitioning::DistributeBy(distribute_list) => Ok(distribute_list .iter() .map(|e| 
PyExpr::from(e.clone(), Some(vec![self.repartition.input.clone()]))) .collect()), _ => Err(py_type_err("unexpected repartition strategy")), } }
    /// Collects the names of the `Partitioning::DistributeBy` column expressions
    /// into this method's return value.
    ///
    /// Returns a type error instead of panicking when the partitioning scheme is
    /// not `DistributeBy`, or when one of the distribute expressions is not a
    /// plain `Expr::Column`, so the failure surfaces as a regular Python
    /// exception.
    #[pyo3(name = "getDistributionColumns")]
    fn get_distribute_columns(&self) -> PyResult {
        match &self.repartition.partitioning_scheme {
            // The closure yields a `Result` per expression; `collect()` stops at
            // the first `Err` and otherwise gathers the names exactly as before.
            Partitioning::DistributeBy(distribute_list) => distribute_list
                .iter()
                .map(|e| match e {
                    Expr::Column(column) => Ok(column.name.clone()),
                    _ => Err(py_type_err("Encountered a type other than Expr::Column")),
                })
                .collect(),
            _ => Err(py_type_err("unexpected repartition strategy")),
        }
    }
} impl TryFrom for PyRepartitionBy { type Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::Repartition(repartition) => Ok(PyRepartitionBy { repartition }), _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/show_columns.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{ logical_plan::{Extension, UserDefinedLogicalNode}, Expr, LogicalPlan, }, }; use fmt::Debug; use pyo3::prelude::*; use crate::sql::{exceptions::py_type_err, logical}; #[derive(Clone, PartialEq)] pub struct ShowColumnsPlanNode { pub schema: DFSchemaRef, pub table_name: String, pub schema_name: Option, } impl Debug for ShowColumnsPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for ShowColumnsPlanNode { fn hash(&self, state: &mut H) { self.schema.hash(state); self.table_name.hash(state); self.schema_name.hash(state); } } impl UserDefinedLogicalNode for ShowColumnsPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything 
with expressions that are specific to // SHOW COLUMNS FROM {table_name} vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Show Columns: table_name: {:?}", self.table_name) } fn from_template( &self, _exprs: &[Expr], _inputs: &[LogicalPlan], ) -> Arc { Arc::new(ShowColumnsPlanNode { schema: Arc::new(DFSchema::empty()), table_name: self.table_name.clone(), schema_name: self.schema_name.clone(), }) } fn name(&self) -> &str { "ShowColumns" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "ShowColumns", module = "dask_sql", subclass)] pub struct PyShowColumns { pub(crate) show_columns: ShowColumnsPlanNode, } #[pymethods] impl PyShowColumns { #[pyo3(name = "getTableName")] fn get_table_name(&self) -> PyResult { Ok(self.show_columns.table_name.clone()) } #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult> { Ok(self.show_columns.schema_name.clone()) } } impl TryFrom for PyShowColumns { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { LogicalPlan::Extension(Extension { node }) if node .as_any() .downcast_ref::() .is_some() => { let ext = node .as_any() .downcast_ref::() .expect("ShowColumnsPlanNode"); Ok(PyShowColumns { show_columns: ext.clone(), }) } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/show_models.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan}, }; use fmt::Debug; use pyo3::prelude::*; use crate::sql::logical::py_type_err; #[derive(Clone, PartialEq)] pub struct ShowModelsPlanNode { 
pub schema: DFSchemaRef, pub schema_name: Option, } impl Debug for ShowModelsPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for ShowModelsPlanNode { fn hash(&self, state: &mut H) { self.schema.hash(state); self.schema_name.hash(state); } } impl UserDefinedLogicalNode for ShowModelsPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // SHOW MODELS vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "ShowModels") } fn from_template( &self, _exprs: &[Expr], _inputs: &[LogicalPlan], ) -> Arc { Arc::new(ShowModelsPlanNode { schema: Arc::new(DFSchema::empty()), schema_name: self.schema_name.clone(), }) } fn name(&self) -> &str { "ShowModels" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "ShowModels", module = "dask_sql", subclass)] pub struct PyShowModels { pub(crate) show_models: ShowModelsPlanNode, } #[pymethods] impl PyShowModels { #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult> { Ok(self.show_models.schema_name.clone()) } } impl TryFrom for PyShowModels { type Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::Extension(extension) => { if let Some(ext) = extension.node.as_any().downcast_ref::() { Ok(PyShowModels { show_models: ext.clone(), }) } else { Err(py_type_err("unexpected plan")) } } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/show_schemas.rs 
================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{ logical_plan::{Extension, UserDefinedLogicalNode}, Expr, LogicalPlan, }, }; use fmt::Debug; use pyo3::prelude::*; use crate::sql::{exceptions::py_type_err, logical}; #[derive(Clone, PartialEq)] pub struct ShowSchemasPlanNode { pub schema: DFSchemaRef, pub catalog_name: Option, pub like: Option, } impl Debug for ShowSchemasPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for ShowSchemasPlanNode { fn hash(&self, state: &mut H) { self.schema.hash(state); self.like.hash(state); } } impl UserDefinedLogicalNode for ShowSchemasPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // SHOW SCHEMAS vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "ShowSchema: catalog_name: {:?}", self.catalog_name) } fn from_template( &self, _exprs: &[Expr], _inputs: &[LogicalPlan], ) -> Arc { Arc::new(ShowSchemasPlanNode { schema: Arc::new(DFSchema::empty()), catalog_name: self.catalog_name.clone(), like: self.like.clone(), }) } fn name(&self) -> &str { "ShowSchema" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "ShowSchema", module = "dask_sql", subclass)] pub struct PyShowSchema { pub(crate) show_schema: ShowSchemasPlanNode, } #[pymethods] impl PyShowSchema { #[pyo3(name = "getCatalogName")] fn get_from(&self) -> PyResult> { 
Ok(self.show_schema.catalog_name.clone()) } #[pyo3(name = "getLike")] fn get_like(&self) -> PyResult> { Ok(self.show_schema.like.clone()) } } impl TryFrom for PyShowSchema { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { LogicalPlan::Extension(Extension { node }) if node .as_any() .downcast_ref::() .is_some() => { let ext = node .as_any() .downcast_ref::() .expect("ShowSchemasPlanNode"); Ok(PyShowSchema { show_schema: ext.clone(), }) } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/show_tables.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{ logical_plan::{Extension, UserDefinedLogicalNode}, Expr, LogicalPlan, }, }; use fmt::Debug; use pyo3::prelude::*; use crate::sql::{exceptions::py_type_err, logical}; #[derive(Clone, PartialEq)] pub struct ShowTablesPlanNode { pub schema: DFSchemaRef, pub catalog_name: Option, pub schema_name: Option, } impl Debug for ShowTablesPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for ShowTablesPlanNode { fn hash(&self, state: &mut H) { self.schema.hash(state); self.schema_name.hash(state); } } impl UserDefinedLogicalNode for ShowTablesPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // SHOW TABLES FROM {schema_name} vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "ShowTables: catalog_name: {:?}, schema_name: {:?}", self.catalog_name, self.schema_name ) } fn from_template( &self, _exprs: &[Expr], 
_inputs: &[LogicalPlan], ) -> Arc { Arc::new(ShowTablesPlanNode { schema: Arc::new(DFSchema::empty()), catalog_name: self.catalog_name.clone(), schema_name: self.schema_name.clone(), }) } fn name(&self) -> &str { "ShowTables" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "ShowTables", module = "dask_sql", subclass)] pub struct PyShowTables { pub(crate) show_tables: ShowTablesPlanNode, } #[pymethods] impl PyShowTables { #[pyo3(name = "getCatalogName")] fn get_catalog_name(&self) -> PyResult> { Ok(self.show_tables.catalog_name.clone()) } #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult> { Ok(self.show_tables.schema_name.clone()) } } impl TryFrom for PyShowTables { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { LogicalPlan::Extension(Extension { node }) if node.as_any().downcast_ref::().is_some() => { let ext = node .as_any() .downcast_ref::() .expect("ShowTablesPlanNode"); Ok(PyShowTables { show_tables: ext.clone(), }) } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/sort.rs ================================================ use datafusion_python::datafusion_expr::{logical_plan::Sort, LogicalPlan}; use pyo3::prelude::*; use crate::{ expression::{py_expr_list, PyExpr}, sql::exceptions::py_type_err, }; #[pyclass(name = "Sort", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PySort { sort: Sort, } #[pymethods] impl PySort { /// Returns a Vec of the sort expressions #[pyo3(name = "getCollation")] pub fn sort_expressions(&self) -> PyResult> { py_expr_list(&self.sort.input, &self.sort.expr) } #[pyo3(name = "getNumRows")] pub fn get_fetch_val(&self) -> PyResult> { Ok(self.sort.fetch) } } impl TryFrom for PySort { type 
Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::Sort(sort) => Ok(PySort { sort }), _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/subquery_alias.rs ================================================ use datafusion_python::datafusion_expr::{logical_plan::SubqueryAlias, LogicalPlan}; use pyo3::prelude::*; use crate::sql::exceptions::py_type_err; #[pyclass(name = "SubqueryAlias", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PySubqueryAlias { subquery_alias: SubqueryAlias, } #[pymethods] impl PySubqueryAlias { /// Returns the alias of this subquery, rendered as a string #[pyo3(name = "getAlias")] pub fn alias(&self) -> PyResult { Ok(self.subquery_alias.alias.clone().to_string()) } } impl TryFrom for PySubqueryAlias { type Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::SubqueryAlias(subquery_alias) => Ok(PySubqueryAlias { subquery_alias }), _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/table_scan.rs ================================================ use std::{sync::Arc, vec}; use datafusion_python::{ datafusion_common::{DFSchema, ScalarValue}, datafusion_expr::{ expr::{Alias, InList}, logical_plan::TableScan, Expr, LogicalPlan, }, }; use pyo3::prelude::*; use crate::{ error::DaskPlannerError, expression::{py_expr_list, PyExpr}, sql::exceptions::py_type_err, }; #[pyclass(name = "TableScan", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PyTableScan { pub(crate) table_scan: TableScan, input: Arc, } type FilterTuple = (String, String, Option>); #[pyclass(name = "FilteredResult", module = "dask_sql", subclass)] #[derive(Debug, Clone)] pub struct PyFilteredResult { // Certain Expr(s) do not have supporting logic in pyarrow for IO filtering // at read time. Those Expr(s) cannot be ignored however. 
This field stores // those Expr(s) so that they can be used on the Python side to create // Dask operations that handle that filtering as an extra task in the graph. #[pyo3(get)] pub io_unfilterable_exprs: Vec, // Expr(s) that can have their filtering logic performed in the pyarrow IO logic // are stored here in a DNF format that is expected by pyarrow. #[pyo3(get)] pub filtered_exprs: Vec<(PyExpr, FilterTuple)>, } impl PyTableScan { /// Ensures that a valid Expr variant type is present fn _valid_expr_type(expr: &[Expr]) -> bool { expr.iter() .all(|f| matches!(f, Expr::Column(_) | Expr::Literal(_))) } /// Transform the singular Expr instance into its DNF form serialized in a Vec instance. Possibly recursively expanding /// it as well if needed. pub fn _expand_dnf_filter( filter: &Expr, input: &Arc, py: Python, ) -> Result, DaskPlannerError> { let mut filter_tuple: Vec<(PyExpr, FilterTuple)> = Vec::new(); match filter { Expr::InList(InList { expr, list, negated, }) => { // Only handle simple Expr(s) for InList operations for now if PyTableScan::_valid_expr_type(list) { // While ANSI SQL would not allow for anything other than a Column or Literal // value in this "identifying" `expr` we explicitly check that here just to be sure. // IF it is something else it is returned to Dask to handle let ident = match *expr.clone() { Expr::Column(col) => Ok(col.name), Expr::Alias(Alias { name, .. }) => Ok(name), Expr::Literal(val) => Ok(format!("{}", val)), _ => Err(DaskPlannerError::InvalidIOFilter(format!( "Invalid InList Expr type `{}`. 
using in Dask instead", filter ))), }; let op = if *negated { "not in" } else { "in" };
    // Convert every element of the IN-list into a Python object for the DNF
    // filter tuple. Only literals that actually carry a value are converted:
    // a NULL literal (`ScalarValue::X(None)`) matches none of the `Some(..)`
    // patterns below and reaches the `InvalidIOFilter` arm, so the whole
    // filter is handed back to Dask instead of panicking on `unwrap()`.
    let il: Result, DaskPlannerError> = list
        .iter()
        .map(|f| match f {
            Expr::Column(col) => Ok(col.name.clone().into_py(py)),
            Expr::Alias(Alias { name, ..}) => Ok(name.clone().into_py(py)),
            Expr::Literal(val) => match val {
                ScalarValue::Boolean(Some(v)) => Ok((*v).into_py(py)),
                ScalarValue::Float32(Some(v)) => Ok((*v).into_py(py)),
                ScalarValue::Float64(Some(v)) => Ok((*v).into_py(py)),
                ScalarValue::Int8(Some(v)) => Ok((*v).into_py(py)),
                ScalarValue::Int16(Some(v)) => Ok((*v).into_py(py)),
                ScalarValue::Int32(Some(v)) => Ok((*v).into_py(py)),
                ScalarValue::Int64(Some(v)) => Ok((*v).into_py(py)),
                ScalarValue::UInt8(Some(v)) => Ok((*v).into_py(py)),
                ScalarValue::UInt16(Some(v)) => Ok((*v).into_py(py)),
                ScalarValue::UInt32(Some(v)) => Ok((*v).into_py(py)),
                ScalarValue::UInt64(Some(v)) => Ok((*v).into_py(py)),
                ScalarValue::Utf8(Some(v)) => Ok(v.clone().into_py(py)),
                ScalarValue::LargeUtf8(Some(v)) => Ok(v.clone().into_py(py)),
                // Unsupported scalar types and NULL literals both end up here.
                _ => Err(DaskPlannerError::InvalidIOFilter(format!(
                    "Unsupported ScalarValue `{}` encountered. using in Dask instead",
                    filter
                ))),
            },
            _ => Ok(f.canonical_name().into_py(py)),
        })
        .collect();
filter_tuple.push(( PyExpr::from(filter.clone(), Some(vec![input.clone()])), ( ident.unwrap_or(expr.canonical_name()), op.to_string(), Some(il?), ), )); Ok(filter_tuple) } else { let er = DaskPlannerError::InvalidIOFilter(format!( "Invalid identifying column Expr instance `{}`. using in Dask instead", filter )); Err::, DaskPlannerError>(er) } } Expr::IsNotNull(expr) => { // Only handle simple Expr(s) for IsNotNull operations for now let ident = match *expr.clone() { Expr::Column(col) => Ok(col.name), _ => Err(DaskPlannerError::InvalidIOFilter(format!( "Invalid IsNotNull Expr type `{}`. 
using in Dask instead", filter ))), }; filter_tuple.push(( PyExpr::from(filter.clone(), Some(vec![input.clone()])), ( ident.unwrap_or(expr.canonical_name()), "is not".to_string(), None, ), )); Ok(filter_tuple) } _ => { let er = DaskPlannerError::InvalidIOFilter(format!( "Unable to apply filter: `{}` to IO reader, using in Dask instead", filter )); Err::, DaskPlannerError>(er) } } } /// Consume the `TableScan` filters (Expr(s)) and convert them into a PyArrow understandable /// DNF format that can be directly passed to PyArrow IO readers for Predicate Pushdown. Expr(s) /// that cannot be converted to correlating PyArrow IO calls will be returned as is and can be /// used in the Python logic to form Dask tasks for the graph to do computational filtering. pub fn _expand_dnf_filters( input: &Arc, filters: &[Expr], py: Python, ) -> PyFilteredResult { let mut filtered_exprs: Vec<(PyExpr, FilterTuple)> = Vec::new(); let mut unfiltered_exprs: Vec = Vec::new(); filters .iter() .for_each(|f| match PyTableScan::_expand_dnf_filter(f, input, py) { Ok(mut expanded_dnf_filter) => filtered_exprs.append(&mut expanded_dnf_filter), Err(_e) => { unfiltered_exprs.push(PyExpr::from(f.clone(), Some(vec![input.clone()]))) } }); PyFilteredResult { io_unfilterable_exprs: unfiltered_exprs, filtered_exprs, } } } #[pymethods] impl PyTableScan { #[pyo3(name = "getTableScanProjects")] fn scan_projects(&mut self) -> PyResult> { match &self.table_scan.projection { Some(indices) => { let schema = self.table_scan.source.schema(); Ok(indices .iter() .map(|i| schema.field(*i).name().to_string()) .collect()) } None => Ok(vec![]), } } /// If the 'TableScan' contains columns that should be projected during the /// read return True, otherwise return False #[pyo3(name = "containsProjections")] fn contains_projections(&self) -> bool { self.table_scan.projection.is_some() } #[pyo3(name = "getFilters")] fn scan_filters(&self) -> PyResult> { py_expr_list(&self.input, &self.table_scan.filters) } #[pyo3(name = 
"getDNFFilters")] fn dnf_io_filters(&self, py: Python) -> PyResult { let results = PyTableScan::_expand_dnf_filters(&self.input, &self.table_scan.filters, py); Ok(results) } } impl TryFrom for PyTableScan { type Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::TableScan(table_scan) => { // Create an input logical plan that's identical to the table scan with schema from the table source let mut input = table_scan.clone(); input.projected_schema = DFSchema::try_from_qualified_schema( &table_scan.table_name, &table_scan.source.schema(), ) .map_or(input.projected_schema, Arc::new); Ok(PyTableScan { table_scan, input: Arc::new(LogicalPlan::TableScan(input)), }) } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/use_schema.rs ================================================ use std::{ any::Any, fmt, hash::{Hash, Hasher}, sync::Arc, }; use datafusion_python::{ datafusion_common::{DFSchema, DFSchemaRef}, datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan}, }; use fmt::Debug; use pyo3::prelude::*; use crate::sql::{exceptions::py_type_err, logical}; #[derive(Clone, PartialEq)] pub struct UseSchemaPlanNode { pub schema: DFSchemaRef, pub schema_name: String, } impl Debug for UseSchemaPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.fmt_for_explain(f) } } impl Hash for UseSchemaPlanNode { fn hash(&self, state: &mut H) { self.schema.hash(state); self.schema_name.hash(state); } } impl UserDefinedLogicalNode for UseSchemaPlanNode { fn as_any(&self) -> &dyn Any { self } fn inputs(&self) -> Vec<&LogicalPlan> { vec![] } fn schema(&self) -> &DFSchemaRef { &self.schema } fn expressions(&self) -> Vec { // there is no need to expose any expressions here since DataFusion would // not be able to do anything with expressions that are specific to // USE SCHEMA vec![] } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> 
fmt::Result { write!(f, "UseSchema: schema_name={}", self.schema_name) } fn from_template( &self, _exprs: &[Expr], _inputs: &[LogicalPlan], ) -> Arc { Arc::new(UseSchemaPlanNode { schema: Arc::new(DFSchema::empty()), schema_name: self.schema_name.clone(), }) } fn name(&self) -> &str { "UseSchema" } fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); } fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { match other.as_any().downcast_ref::() { Some(o) => self == o, None => false, } } } #[pyclass(name = "UseSchema", module = "dask_sql", subclass)] pub struct PyUseSchema { pub(crate) use_schema: UseSchemaPlanNode, } #[pymethods] impl PyUseSchema { #[pyo3(name = "getSchemaName")] fn get_schema_name(&self) -> PyResult { Ok(self.use_schema.schema_name.clone()) } } impl TryFrom for PyUseSchema { type Error = PyErr; fn try_from(logical_plan: logical::LogicalPlan) -> Result { match logical_plan { logical::LogicalPlan::Extension(extension) => { if let Some(ext) = extension.node.as_any().downcast_ref::() { Ok(PyUseSchema { use_schema: ext.clone(), }) } else { Err(py_type_err("unexpected plan")) } } _ => Err(py_type_err("unexpected plan")), } } } ================================================ FILE: src/sql/logical/window.rs ================================================ use datafusion_python::{ datafusion_common::ScalarValue, datafusion_expr::{ expr::WindowFunction, logical_plan::Window, Expr, LogicalPlan, WindowFrame, WindowFrameBound, }, }; use pyo3::prelude::*; use crate::{ error::DaskPlannerError, expression::{py_expr_list, PyExpr}, sql::exceptions::py_type_err, }; #[pyclass(name = "Window", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PyWindow { window: Window, } #[pyclass(name = "WindowFrame", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PyWindowFrame { window_frame: WindowFrame, } #[pyclass(name = "WindowFrameBound", module = "dask_sql", subclass)] #[derive(Clone)] pub struct PyWindowFrameBound { 
frame_bound: WindowFrameBound, } impl TryFrom for PyWindow { type Error = PyErr; fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { LogicalPlan::Window(window) => Ok(PyWindow { window }), _ => Err(py_type_err("unexpected plan")), } } } impl From for PyWindowFrame { fn from(window_frame: WindowFrame) -> Self { PyWindowFrame { window_frame } } } impl From for PyWindowFrameBound { fn from(frame_bound: WindowFrameBound) -> Self { PyWindowFrameBound { frame_bound } } } #[pymethods] impl PyWindow { /// Returns window expressions #[pyo3(name = "getGroups")] pub fn get_window_expr(&self) -> PyResult> { py_expr_list(&self.window.input, &self.window.window_expr) } /// Returns order by columns in a window function expression #[pyo3(name = "getSortExprs")] pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult> { match expr.expr.unalias() { Expr::WindowFunction(WindowFunction { order_by, .. }) => { py_expr_list(&self.window.input, &order_by) } other => Err(not_window_function_err(other)), } } /// Return partition by columns in a window function expression #[pyo3(name = "getPartitionExprs")] pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult> { match expr.expr.unalias() { Expr::WindowFunction(WindowFunction { partition_by, .. }) => { py_expr_list(&self.window.input, &partition_by) } other => Err(not_window_function_err(other)), } } /// Return input args for window function #[pyo3(name = "getArgs")] pub fn get_args(&self, expr: PyExpr) -> PyResult> { match expr.expr.unalias() { Expr::WindowFunction(WindowFunction { args, .. }) => { py_expr_list(&self.window.input, &args) } other => Err(not_window_function_err(other)), } } /// Return window function name #[pyo3(name = "getWindowFuncName")] pub fn window_func_name(&self, expr: PyExpr) -> PyResult { match expr.expr.unalias() { Expr::WindowFunction(WindowFunction { fun, .. 
}) => Ok(fun.to_string()), other => Err(not_window_function_err(other)), } } /// Returns a Pywindow frame for a given window function expression #[pyo3(name = "getWindowFrame")] pub fn get_window_frame(&self, expr: PyExpr) -> Option { match expr.expr.unalias() { Expr::WindowFunction(WindowFunction { window_frame, .. }) => Some(window_frame.into()), _ => None, } } } fn not_window_function_err(expr: Expr) -> PyErr { py_type_err(format!( "Provided {} Expr {:?} is not a WindowFunction type", expr.variant_name(), expr )) } #[pymethods] impl PyWindowFrame { /// Returns the window frame units for the bounds #[pyo3(name = "getFrameUnit")] pub fn get_frame_units(&self) -> PyResult { Ok(self.window_frame.units.to_string()) } /// Returns starting bound #[pyo3(name = "getLowerBound")] pub fn get_lower_bound(&self) -> PyResult { Ok(self.window_frame.start_bound.clone().into()) } /// Returns end bound #[pyo3(name = "getUpperBound")] pub fn get_upper_bound(&self) -> PyResult { Ok(self.window_frame.end_bound.clone().into()) } } #[pymethods] impl PyWindowFrameBound { /// Returns if the frame bound is current row #[pyo3(name = "isCurrentRow")] pub fn is_current_row(&self) -> bool { matches!(self.frame_bound, WindowFrameBound::CurrentRow) } /// Returns if the frame bound is preceding #[pyo3(name = "isPreceding")] pub fn is_preceding(&self) -> bool { matches!(self.frame_bound, WindowFrameBound::Preceding(_)) } /// Returns if the frame bound is following #[pyo3(name = "isFollowing")] pub fn is_following(&self) -> bool { matches!(self.frame_bound, WindowFrameBound::Following(_)) } /// Returns the offset of the window frame #[pyo3(name = "getOffset")] pub fn get_offset(&self) -> PyResult> { match &self.frame_bound { WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val { x if x.is_null() => Ok(None), ScalarValue::UInt64(v) => Ok(*v), // The cast below is only safe because window bounds cannot be negative ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)), 
ScalarValue::Utf8(v) => { let s = v.clone().unwrap(); match s.parse::() { Ok(s) => Ok(Some(s)), Err(_e) => Err(DaskPlannerError::Internal(format!( "Unable to parse u64 from Utf8 value '{s}'" )) .into()), } } ref x => Err(DaskPlannerError::Internal(format!( "Unexpected window frame bound: {x}" )) .into()), }, WindowFrameBound::CurrentRow => Ok(None), } } /// Returns if the frame bound is unbounded #[pyo3(name = "isUnbounded")] pub fn is_unbounded(&self) -> PyResult { match &self.frame_bound { WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()), WindowFrameBound::CurrentRow => Ok(false), } } } ================================================ FILE: src/sql/logical.rs ================================================ use crate::sql::{ table, types::{rel_data_type::RelDataType, rel_data_type_field::RelDataTypeField}, }; pub mod aggregate; pub mod alter_schema; pub mod alter_table; pub mod analyze_table; pub mod create_catalog_schema; pub mod create_experiment; pub mod create_memory_table; pub mod create_model; pub mod create_table; pub mod describe_model; pub mod drop_model; pub mod drop_schema; pub mod drop_table; pub mod empty_relation; pub mod explain; pub mod export_model; pub mod filter; pub mod join; pub mod limit; pub mod predict_model; pub mod projection; pub mod repartition_by; pub mod show_columns; pub mod show_models; pub mod show_schemas; pub mod show_tables; pub mod sort; pub mod subquery_alias; pub mod table_scan; pub mod use_schema; pub mod window; use datafusion_python::{ datafusion_common::{DFSchemaRef, DataFusionError}, datafusion_expr::{DdlStatement, LogicalPlan}, }; use pyo3::prelude::*; use self::{ alter_schema::AlterSchemaPlanNode, alter_table::AlterTablePlanNode, analyze_table::AnalyzeTablePlanNode, create_catalog_schema::CreateCatalogSchemaPlanNode, create_experiment::CreateExperimentPlanNode, create_model::CreateModelPlanNode, create_table::CreateTablePlanNode, describe_model::DescribeModelPlanNode, 
drop_model::DropModelPlanNode, drop_schema::DropSchemaPlanNode, export_model::ExportModelPlanNode, predict_model::PredictModelPlanNode, show_columns::ShowColumnsPlanNode, show_models::ShowModelsPlanNode, show_schemas::ShowSchemasPlanNode, show_tables::ShowTablesPlanNode, use_schema::UseSchemaPlanNode, }; use crate::{error::Result, sql::exceptions::py_type_err}; #[pyclass(name = "LogicalPlan", module = "dask_sql", subclass)] #[derive(Debug, Clone)] pub struct PyLogicalPlan { /// The original LogicalPlan that was parsed by DataFusion from the input SQL pub(crate) original_plan: LogicalPlan, /// The original_plan is traversed. current_node stores the current node of this traversal pub(crate) current_node: Option, } /// Unfortunately PyO3 forces us to do this as placing these methods in the #[pymethods] version /// of `impl PyLogicalPlan` causes issues with types not properly being mapped to Python from Rust impl PyLogicalPlan { /// Getter method for the LogicalPlan, if current_node is None return original_plan. 
pub(crate) fn current_node(&mut self) -> LogicalPlan { match &self.current_node { Some(current) => current.clone(), None => { self.current_node = Some(self.original_plan.clone()); self.current_node.clone().unwrap() } } } } /// Convert a LogicalPlan to a Python equivalent type fn to_py_plan>( current_node: Option<&LogicalPlan>, ) -> PyResult { match current_node { Some(plan) => plan.clone().try_into(), _ => Err(py_type_err("current_node was None")), } } #[pymethods] impl PyLogicalPlan { /// LogicalPlan::Aggregate as PyAggregate pub fn aggregate(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::EmptyRelation as PyEmptyRelation pub fn empty_relation(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Explain as PyExplain pub fn explain(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Filter as PyFilter pub fn filter(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Join as PyJoin pub fn join(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Limit as PyLimit pub fn limit(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Projection as PyProjection pub fn projection(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Sort as PySort pub fn sort(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::SubqueryAlias as PySubqueryAlias pub fn subquery_alias(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Window as PyWindow pub fn window(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::TableScan as PyTableScan pub fn table_scan(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::CreateMemoryTable as PyCreateMemoryTable pub fn create_memory_table(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::CreateModel as PyCreateModel pub fn 
create_model(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::CreateExperiment as PyCreateExperiment pub fn create_experiment(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::DropTable as DropTable pub fn drop_table(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::DropModel as DropModel pub fn drop_model(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Extension::ShowSchemas as PyShowSchemas pub fn show_schemas(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Repartition as PyRepartitionBy pub fn repartition_by(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Extension::ShowTables as PyShowTables pub fn show_tables(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Extension::CreateTable as PyCreateTable pub fn create_table(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Extension::PredictModel as PyPredictModel pub fn predict_model(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Extension::DescribeModel as PyDescribeModel pub fn describe_model(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Extension::ExportModel as PyExportModel pub fn export_model(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Extension::ShowColumns as PyShowColumns pub fn show_columns(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } pub fn show_models(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Extension::ShowColumns as PyShowColumns pub fn analyze_table(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::CreateCatalogSchema as PyCreateCatalogSchema pub fn create_catalog_schema(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Extension::DropSchema as PyDropSchema 
pub fn drop_schema(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Extension::UseSchema as PyUseSchema pub fn use_schema(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Extension::AlterTable as PyAlterTable pub fn alter_table(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// LogicalPlan::Extension::AlterSchema as PyAlterSchema pub fn alter_schema(&self) -> PyResult { to_py_plan(self.current_node.as_ref()) } /// Gets the "input" for the current LogicalPlan pub fn get_inputs(&mut self) -> PyResult> { let mut py_inputs: Vec = Vec::new(); for input in self.current_node().inputs() { py_inputs.push(input.clone().into()); } Ok(py_inputs) } /// If the LogicalPlan represents access to a Table that instance is returned /// otherwise None is returned #[pyo3(name = "getTable")] pub fn table(&mut self) -> PyResult { match table::table_from_logical_plan(&self.current_node())? { Some(table) => Ok(table), None => Err(py_type_err( "Unable to compute DaskTable from DataFusion LogicalPlan", )), } } #[pyo3(name = "getCurrentNodeSchemaName")] pub fn get_current_node_schema_name(&self) -> PyResult<&str> { match &self.current_node { Some(e) => { let _sch: &DFSchemaRef = e.schema(); //TODO: Where can I actually get this in the context of the running query? Ok("root") } None => Err(py_type_err(DataFusionError::Plan(format!( "Current schema not found. Defaulting to {:?}", "root" )))), } } #[pyo3(name = "getCurrentNodeTableName")] pub fn get_current_node_table_name(&mut self) -> PyResult { match self.table() { Ok(dask_table) => Ok(dask_table.table_name), Err(_e) => Err(py_type_err("Unable to determine current node table name")), } } /// Gets the Relation "type" of the current node. 
Ex: Projection, TableScan, etc pub fn get_current_node_type(&mut self) -> PyResult<&str> { Ok(match self.current_node() { LogicalPlan::Dml(_) => "DataManipulationLanguage", LogicalPlan::DescribeTable(_) => "DescribeTable", LogicalPlan::Prepare(_) => "Prepare", LogicalPlan::Distinct(_) => "Distinct", LogicalPlan::Projection(_projection) => "Projection", LogicalPlan::Filter(_filter) => "Filter", LogicalPlan::Window(_window) => "Window", LogicalPlan::Aggregate(_aggregate) => "Aggregate", LogicalPlan::Sort(_sort) => "Sort", LogicalPlan::Join(_join) => "Join", LogicalPlan::CrossJoin(_cross_join) => "CrossJoin", LogicalPlan::Repartition(_repartition) => "Repartition", LogicalPlan::Union(_union) => "Union", LogicalPlan::TableScan(_table_scan) => "TableScan", LogicalPlan::EmptyRelation(_empty_relation) => "EmptyRelation", LogicalPlan::Limit(_limit) => "Limit", LogicalPlan::Ddl(DdlStatement::CreateExternalTable { .. }) => "CreateExternalTable", LogicalPlan::Ddl(DdlStatement::CreateMemoryTable { .. }) => "CreateMemoryTable", LogicalPlan::Ddl(DdlStatement::DropTable { .. }) => "DropTable", LogicalPlan::Ddl(DdlStatement::DropView { .. }) => "DropView", LogicalPlan::Values(_values) => "Values", LogicalPlan::Explain(_explain) => "Explain", LogicalPlan::Analyze(_analyze) => "Analyze", LogicalPlan::Subquery(_sub_query) => "Subquery", LogicalPlan::SubqueryAlias(_sqalias) => "SubqueryAlias", LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema { .. }) => "CreateCatalogSchema", LogicalPlan::Ddl(DdlStatement::DropCatalogSchema { .. }) => "DropCatalogSchema", LogicalPlan::Ddl(DdlStatement::CreateCatalog { .. }) => "CreateCatalog", LogicalPlan::Ddl(DdlStatement::CreateView { .. 
}) => "CreateView", LogicalPlan::Statement(_) => "Statement", // Further examine and return the name that is a possible Dask-SQL Extension type LogicalPlan::Extension(extension) => { let node = extension.node.as_any(); if node.downcast_ref::().is_some() { "CreateModel" } else if node.downcast_ref::().is_some() { "CreateExperiment" } else if node.downcast_ref::().is_some() { "CreateCatalogSchema" } else if node.downcast_ref::().is_some() { "CreateTable" } else if node.downcast_ref::().is_some() { "DropModel" } else if node.downcast_ref::().is_some() { "PredictModel" } else if node.downcast_ref::().is_some() { "ExportModel" } else if node.downcast_ref::().is_some() { "DescribeModel" } else if node.downcast_ref::().is_some() { "ShowSchemas" } else if node.downcast_ref::().is_some() { "ShowTables" } else if node.downcast_ref::().is_some() { "ShowColumns" } else if node.downcast_ref::().is_some() { "ShowModels" } else if node.downcast_ref::().is_some() { "DropSchema" } else if node.downcast_ref::().is_some() { "UseSchema" } else if node.downcast_ref::().is_some() { "AnalyzeTable" } else if node.downcast_ref::().is_some() { "AlterTable" } else if node.downcast_ref::().is_some() { "AlterSchema" } else { // Default to generic `Extension` "Extension" } } LogicalPlan::Unnest(_unnest) => "Unnest", LogicalPlan::Copy(_) => "Copy", }) } /// Explain plan for the full and original LogicalPlan pub fn explain_original(&self) -> PyResult { Ok(format!("{}", self.original_plan.display_indent())) } /// Explain plan from the current node onward pub fn explain_current(&mut self) -> PyResult { Ok(format!("{}", self.current_node().display_indent())) } #[pyo3(name = "getRowType")] pub fn row_type(&self) -> PyResult { match &self.original_plan { LogicalPlan::Join(join) => { let mut lhs_fields: Vec = join .left .schema() .fields() .iter() .map(|f| RelDataTypeField::from(f, join.left.schema().as_ref())) .collect::>>() .map_err(py_type_err)?; let mut rhs_fields: Vec = join .right .schema() 
.fields() .iter() .map(|f| RelDataTypeField::from(f, join.right.schema().as_ref())) .collect::>>() .map_err(py_type_err)?; lhs_fields.append(&mut rhs_fields); Ok(RelDataType::new(false, lhs_fields)) } LogicalPlan::Distinct(distinct) => { let schema = distinct.input.schema(); let rel_fields: Vec = schema .fields() .iter() .map(|f| RelDataTypeField::from(f, schema.as_ref())) .collect::>>() .map_err(py_type_err)?; Ok(RelDataType::new(false, rel_fields)) } _ => { let schema = self.original_plan.schema(); let rel_fields: Vec = schema .fields() .iter() .map(|f| RelDataTypeField::from(f, schema.as_ref())) .collect::>>() .map_err(py_type_err)?; Ok(RelDataType::new(false, rel_fields)) } } } } impl From for LogicalPlan { fn from(logical_plan: PyLogicalPlan) -> LogicalPlan { logical_plan.original_plan } } impl From for PyLogicalPlan { fn from(logical_plan: LogicalPlan) -> PyLogicalPlan { PyLogicalPlan { original_plan: logical_plan, current_node: None, } } } ================================================ FILE: src/sql/optimizer/decorrelate_where_exists.rs ================================================ // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. 
use std::sync::Arc; use datafusion_python::{ datafusion_common::{Column, DataFusionError, Result}, datafusion_expr::{ expr::Exists, logical_plan::{Distinct, Filter, JoinType, Subquery}, Expr, LogicalPlan, LogicalPlanBuilder, }, datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule}, }; use crate::sql::optimizer::utils::{ collect_subquery_cols, conjunction, extract_join_filters, split_conjunction, }; /// Optimizer rule for rewriting subquery filters to joins #[derive(Default)] pub struct DecorrelateWhereExists {} impl DecorrelateWhereExists { #[allow(missing_docs)] pub fn new() -> Self { Self {} } /// Finds expressions that have a where in subquery (and recurse when found) /// /// # Arguments /// /// * `predicate` - A conjunction to split and search /// * `optimizer_config` - For generating unique subquery aliases /// /// Returns a tuple (subqueries, non-subquery expressions) fn extract_subquery_exprs( &self, predicate: &Expr, config: &dyn OptimizerConfig, ) -> Result<(Vec, Vec)> { let filters = split_conjunction(predicate); let mut subqueries = vec![]; let mut others = vec![]; for it in filters.iter() { match it { Expr::Exists(Exists { subquery, negated }) => { let subquery_plan = self .try_optimize(&subquery.subquery, config)? 
.map(Arc::new) .unwrap_or_else(|| subquery.subquery.clone()); let new_subquery = subquery.with_plan(subquery_plan); subqueries.push(SubqueryInfo::new(new_subquery, *negated)); } _ => others.push((*it).clone()), } } Ok((subqueries, others)) } } impl OptimizerRule for DecorrelateWhereExists { fn try_optimize( &self, plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { match plan { LogicalPlan::Filter(filter) => { let (subqueries, other_exprs) = self.extract_subquery_exprs(&filter.predicate, config)?; if subqueries.is_empty() { // regular filter, no subquery exists clause here return Ok(None); } // iterate through all exists clauses in predicate, turning each into a join let mut cur_input = filter.input.as_ref().clone(); for subquery in subqueries { if let Some(x) = optimize_exists(&subquery, &cur_input)? { cur_input = x; } else { return Ok(None); } } let expr = conjunction(other_exprs); if let Some(expr) = expr { let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; cur_input = LogicalPlan::Filter(new_filter); } Ok(Some(cur_input)) } _ => Ok(None), } } fn name(&self) -> &str { "decorrelate_where_exists" } fn apply_order(&self) -> Option { Some(ApplyOrder::TopDown) } } /// Takes a query like: /// /// SELECT t1.id /// FROM t1 /// WHERE exists /// ( /// SELECT t2.id FROM t2 WHERE t1.id = t2.id /// ) /// /// and optimizes it into: /// /// SELECT t1.id /// FROM t1 LEFT SEMI /// JOIN t2 /// ON t1.id = t2.id /// /// # Arguments /// /// * query_info - The subquery and negated(exists/not exists) info. /// * outer_input - The non-subquery portion (relation t1) fn optimize_exists( query_info: &SubqueryInfo, outer_input: &LogicalPlan, ) -> Result> { let subquery = query_info.query.subquery.as_ref(); if let Some((join_filter, optimized_subquery)) = optimize_subquery(subquery)? 
{ // join our sub query into the main plan let join_type = match query_info.negated { true => JoinType::LeftAnti, false => JoinType::LeftSemi, }; let new_plan = LogicalPlanBuilder::from(outer_input.clone()) .join( optimized_subquery, join_type, (Vec::::new(), Vec::::new()), Some(join_filter), )? .build()?; Ok(Some(new_plan)) } else { Ok(None) } } /// Optimize the subquery and extract the possible join filter. /// This function can't optimize non-correlated subquery, and will return None. fn optimize_subquery(subquery: &LogicalPlan) -> Result> { match subquery { LogicalPlan::Distinct(subqry_distinct) => { let distinct_input = &subqry_distinct.input; let optimized_plan = optimize_subquery(distinct_input)?.map(|(filters, right)| { ( filters, LogicalPlan::Distinct(Distinct { input: Arc::new(right), }), ) }); Ok(optimized_plan) } LogicalPlan::Projection(projection) => { // extract join filters let (join_filters, subquery_input) = extract_join_filters(&projection.input)?; // cannot optimize non-correlated subquery if join_filters.is_empty() { return Ok(None); } let input_schema = subquery_input.schema(); let project_exprs: Vec = collect_subquery_cols(&join_filters, input_schema.clone())? .into_iter() .map(Expr::Column) .collect(); let right = LogicalPlanBuilder::from(subquery_input) .project(project_exprs)? .build()?; // join_filters is not empty. let join_filter = conjunction(join_filters).ok_or_else(|| { DataFusionError::Internal("join filters should not be empty".to_string()) })?; Ok(Some((join_filter, right))) } _ => Ok(None), } } struct SubqueryInfo { query: Subquery, negated: bool, } impl SubqueryInfo { pub fn new(query: Subquery, negated: bool) -> Self { Self { query, negated } } } ================================================ FILE: src/sql/optimizer/decorrelate_where_in.rs ================================================ // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. 
See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. use std::sync::Arc; use datafusion_python::{ datafusion_common::{alias::AliasGenerator, context, Column, DataFusionError, Result}, datafusion_expr::{ expr::InSubquery, expr_rewriter::unnormalize_col, logical_plan::{JoinType, Projection, Subquery}, Expr, Filter, LogicalPlan, LogicalPlanBuilder, }, datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule}, }; use log::debug; use crate::sql::optimizer::utils::{ collect_subquery_cols, conjunction, extract_join_filters, only_or_err, replace_qualified_name, split_conjunction, }; #[derive(Default)] pub struct DecorrelateWhereIn { alias: AliasGenerator, } impl DecorrelateWhereIn { #[allow(missing_docs)] pub fn new() -> Self { Self::default() } /// Finds expressions that have a where in subquery (and recurses when found) /// /// # Arguments /// /// * `predicate` - A conjunction to split and search /// * `optimizer_config` - For generating unique subquery aliases /// /// Returns a tuple (subqueries, non-subquery expressions) fn extract_subquery_exprs( &self, predicate: &Expr, config: &dyn OptimizerConfig, ) -> Result<(Vec, Vec)> { let filters = split_conjunction(predicate); // TODO: disjunctions let mut subqueries = vec![]; let mut others = vec![]; for it in filters.iter() { match it { Expr::InSubquery(InSubquery { expr, 
subquery, negated, }) => { let subquery_plan = self .try_optimize(&subquery.subquery, config)? .map(Arc::new) .unwrap_or_else(|| subquery.subquery.clone()); let new_subquery = subquery.with_plan(subquery_plan); subqueries.push(SubqueryInfo::new(new_subquery, (**expr).clone(), *negated)); // TODO: if subquery doesn't get optimized, optimized children are lost } _ => others.push((*it).clone()), } } Ok((subqueries, others)) } } impl OptimizerRule for DecorrelateWhereIn { fn try_optimize( &self, plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { match plan { LogicalPlan::Filter(filter) => { let (subqueries, other_exprs) = self.extract_subquery_exprs(&filter.predicate, config)?; if subqueries.is_empty() { // regular filter, no subquery exists clause here return Ok(None); } // iterate through all exists clauses in predicate, turning each into a join let mut cur_input = filter.input.as_ref().clone(); for subquery in subqueries { cur_input = optimize_where_in(&subquery, &cur_input, &self.alias)?; } let expr = conjunction(other_exprs); if let Some(expr) = expr { let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; cur_input = LogicalPlan::Filter(new_filter); } Ok(Some(cur_input)) } _ => Ok(None), } } fn name(&self) -> &str { "decorrelate_where_in" } fn apply_order(&self) -> Option { Some(ApplyOrder::TopDown) } } /// Optimize the where in subquery to left-anti/left-semi join. /// If the subquery is a correlated subquery, we need extract the join predicate from the subquery. 
/// /// For example, given a query like: /// `select t1.a, t1.b from t1 where t1 in (select t2.a from t2 where t1.b = t2.b and t1.c > t2.c)` /// /// The optimized plan will be: /// /// ```text /// Projection: t1.a, t1.b /// LeftSemi Join: Filter: t1.a = __correlated_sq_1.a AND t1.b = __correlated_sq_1.b AND t1.c > __correlated_sq_1.c /// TableScan: t1 /// SubqueryAlias: __correlated_sq_1 /// Projection: t2.a AS a, t2.b, t2.c /// TableScan: t2 /// ``` fn optimize_where_in( query_info: &SubqueryInfo, left: &LogicalPlan, alias: &AliasGenerator, ) -> Result { let projection = try_from_plan(&query_info.query.subquery) .map_err(|e| context!("a projection is required", e))?; let subquery_input = projection.input.clone(); // TODO add the validate logic to Analyzer let subquery_expr = only_or_err(projection.expr.as_slice()) .map_err(|e| context!("single expression projection required", e))?; // extract join filters let (join_filters, subquery_input) = extract_join_filters(subquery_input.as_ref())?; // in_predicate may be also include in the join filters, remove it from the join filters. let in_predicate = Expr::eq(query_info.where_in_expr.clone(), subquery_expr.clone()); let join_filters = remove_duplicated_filter(join_filters, in_predicate); // replace qualified name with subquery alias. 
let subquery_alias = alias.next("__correlated_sq"); let input_schema = subquery_input.schema(); let mut subquery_cols = collect_subquery_cols(&join_filters, input_schema.clone())?; let join_filter = conjunction(join_filters).map_or(Ok(None), |filter| { replace_qualified_name(filter, &subquery_cols, &subquery_alias).map(Option::Some) })?; // add projection if let Expr::Column(col) = subquery_expr { subquery_cols.remove(col); } let subquery_expr_name = format!("{:?}", unnormalize_col(subquery_expr.clone())); let first_expr = subquery_expr.clone().alias(subquery_expr_name.clone()); let projection_exprs: Vec = [first_expr] .into_iter() .chain(subquery_cols.into_iter().map(Expr::Column)) .collect(); let right = LogicalPlanBuilder::from(subquery_input) .project(projection_exprs)? .alias(subquery_alias.clone())? .build()?; // join our sub query into the main plan let join_type = match query_info.negated { true => JoinType::LeftAnti, false => JoinType::LeftSemi, }; let right_join_col = Column::new(Some(subquery_alias), subquery_expr_name); let in_predicate = Expr::eq( query_info.where_in_expr.clone(), Expr::Column(right_join_col), ); let join_filter = join_filter .map(|filter| in_predicate.clone().and(filter)) .unwrap_or_else(|| in_predicate); let new_plan = LogicalPlanBuilder::from(left.clone()) .join( right, join_type, (Vec::::new(), Vec::::new()), Some(join_filter), )? 
.build()?; debug!("where in optimized:\n{}", new_plan.display_indent()); Ok(new_plan) } fn remove_duplicated_filter(filters: Vec, in_predicate: Expr) -> Vec { filters .into_iter() .filter(|filter| { if filter == &in_predicate { return false; } // ignore the binary order !match (filter, &in_predicate) { (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => { (a_expr.op == b_expr.op) && (a_expr.left == b_expr.left && a_expr.right == b_expr.right) || (a_expr.left == b_expr.right && a_expr.right == b_expr.left) } _ => false, } }) .collect::>() } fn try_from_plan(plan: &LogicalPlan) -> Result<&Projection> { match plan { LogicalPlan::Projection(it) => Ok(it), _ => Err(DataFusionError::Internal( "Could not coerce into Projection!".to_string(), )), } } struct SubqueryInfo { query: Subquery, where_in_expr: Expr, negated: bool, } impl SubqueryInfo { pub fn new(query: Subquery, expr: Expr, negated: bool) -> Self { Self { query, where_in_expr: expr, negated, } } } ================================================ FILE: src/sql/optimizer/dynamic_partition_pruning.rs ================================================ //! Optimizer rule for dynamic partition pruning (DPP) //! //! DPP refers to a query optimization rule in which distinct values in an inner join are used as //! filters in a table scan. This allows us to eliminate all other rows which do not fit the join //! condition from being read at all. //! //! Furthermore, a table involved in a join may be filtered during a scan, which allows us to //! further prune the values to be read. 
use std::{ collections::{HashMap, HashSet}, fs, hash::{Hash, Hasher}, }; use datafusion_python::{ datafusion::parquet::{ basic::Type as BasicType, file::reader::{FileReader, SerializedFileReader}, record::{reader::RowIter, RowAccessor}, schema::{parser::parse_message_type, types::Type}, }, datafusion_common::{Column, Result, ScalarValue}, datafusion_expr::{ expr::InList, logical_plan::LogicalPlan, Expr, JoinType, Operator, TableScan, }, datafusion_optimizer::{OptimizerConfig, OptimizerRule}, }; use log::warn; use crate::sql::table::DaskTableSource; // Optimizer rule for dynamic partition pruning pub struct DynamicPartitionPruning { /// Ratio of the size of the dimension tables to fact tables fact_dimension_ratio: f64, } impl DynamicPartitionPruning { pub fn new(fact_dimension_ratio: f64) -> Self { Self { fact_dimension_ratio, } } } impl OptimizerRule for DynamicPartitionPruning { fn name(&self) -> &str { "dynamic_partition_pruning" } fn try_optimize( &self, plan: &LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { // Parse the LogicalPlan and store tables and columns being (inner) joined upon. 
We do this // by creating a HashSet of all InnerJoins' join.on and join.filters let join_conds = gather_joins(plan); let tables = gather_tables(plan); let aliases = gather_aliases(plan); if join_conds.is_empty() || tables.is_empty() { // No InnerJoins to optimize with Ok(None) } else { // Find the size of the largest table in the query let mut largest_size = 1_f64; for table in &tables { let table_size = table.1.size.unwrap_or(0) as f64; if table_size > largest_size { largest_size = table_size; } } let mut join_values = vec![]; let mut join_tables = vec![]; let mut join_fields = vec![]; let mut fact_tables = HashSet::new(); // Iterate through all inner joins in the query for join_cond in &join_conds { let join_on = &join_cond.on; for on_i in join_on { // Obtain tables and columns (fields) involved in join let (left_on, right_on) = (&on_i.0, &on_i.1); let (mut left_table, mut right_table) = (None, None); let (mut left_field, mut right_field) = (None, None); if let Expr::Column(c) = left_on { left_table = Some(c.relation.clone().unwrap().to_string().clone()); left_field = Some(c.name.clone()); } if let Expr::Column(c) = right_on { right_table = Some(c.relation.clone().unwrap().to_string().clone()); right_field = Some(c.name.clone()); } // For now, if it is not a join between columns then we skip the rule // TODO: https://github.com/dask-contrib/dask-sql/issues/1121 if left_table.is_none() || right_table.is_none() { continue; } let (mut left_table, mut right_table) = (left_table.unwrap(), right_table.unwrap()); let (left_field, right_field) = (left_field.unwrap(), right_field.unwrap()); let (mut left_filtered_table, mut right_filtered_table) = (None, None); // Check if join uses an alias instead of the table name itself. 
Need to use // the actual table name to obtain its filepath let left_alias = aliases.get(&left_table.clone()); if let Some(t) = left_alias { left_table = t.to_string() } let right_alias = aliases.get(&right_table.clone()); if let Some(t) = right_alias { right_table = t.to_string() } // A more complicated alias, e.g. an alias for a nested select, means it's not // obvious which file(s) should be read if !tables.contains_key(&left_table) || !tables.contains_key(&right_table) { continue; } // Determine whether a table is a fact or dimension table. If it's a dimension // table, we should read it in and use the rule if tables .get(&left_table.clone()) .unwrap() .size .unwrap_or(largest_size as usize) as f64 / largest_size < self.fact_dimension_ratio { left_filtered_table = read_table(left_table.clone(), left_field.clone(), tables.clone()); } else { fact_tables.insert(left_table.clone()); } if tables .get(&right_table.clone()) .unwrap() .size .unwrap_or(largest_size as usize) as f64 / largest_size < self.fact_dimension_ratio { right_filtered_table = read_table(right_table.clone(), right_field.clone(), tables.clone()); } else { fact_tables.insert(right_table.clone()); } join_values.push((left_filtered_table, right_filtered_table)); join_tables.push((left_table, right_table)); join_fields.push((left_field, right_field)); } } // Creates HashMap of all tables and field with their unique values to be set in the // TableScan let filter_values = combine_sets(join_values, join_tables, join_fields, fact_tables); // Optimize and return the plan optimize_table_scans(plan, filter_values) } } } /// Represents relevant information in an InnerJoin #[derive(Clone, Debug, Eq, Hash, PartialEq)] struct JoinInfo { /// Equijoin clause expressed as pairs of (left, right) join expressions on: Vec<(Expr, Expr)>, /// Filters applied during join (non-equi conditions) /// TODO: https://github.com/dask-contrib/dask-sql/issues/1121 filter: Option, } // This function parses through the LogicalPlan, 
grabs relevant information from an InnerJoin, and // adds them to a HashSet fn gather_joins(plan: &LogicalPlan) -> HashSet { let mut current_plan = plan.clone(); let mut join_info = HashSet::new(); loop { if current_plan.inputs().is_empty() { break; } else if current_plan.inputs().len() > 1 { match current_plan { LogicalPlan::Join(ref j) => { if j.join_type == JoinType::Inner { // Store tables and columns that are being (inner) joined upon let info = JoinInfo { on: j.on.clone(), filter: j.filter.clone(), }; join_info.insert(info); // Recurse on left and right inputs of Join let (left_joins, right_joins) = (gather_joins(&j.left), gather_joins(&j.right)); // Add left_joins and right_joins to HashSet join_info.extend(left_joins); join_info.extend(right_joins); } else { // We don't run the rule if there are non-inner joins in the query return HashSet::new(); } } LogicalPlan::CrossJoin(ref c) => { // Recurse on left and right inputs of CrossJoin let (left_joins, right_joins) = (gather_joins(&c.left), gather_joins(&c.right)); // Add left_joins and right_joins to HashSet join_info.extend(left_joins); join_info.extend(right_joins); } LogicalPlan::Union(ref u) => { // Recurse on inputs vector of Union for input in &u.inputs { let joins = gather_joins(input); // Add joins to HashSet join_info.extend(joins); } } _ => { warn!("Skipping optimizer rule 'DynamicPartitionPruning'"); return HashSet::new(); } } break; } else { // Move on to next step current_plan = current_plan.inputs()[0].clone(); } } join_info } /// Represents relevant information in a TableScan #[derive(Clone, Debug, Eq, Hash, PartialEq)] struct TableInfo { /// The name of the table table_name: String, /// The path and filename of the table filepath: String, /// The number of rows in the table size: Option, /// Optional expressions to be used as filters by the table provider filters: Vec, } // This function parses through the LogicalPlan, grabs relevant information from a TableScan, and // adds them to a HashMap 
where the key is the table name fn gather_tables(plan: &LogicalPlan) -> HashMap { let mut current_plan = plan.clone(); let mut tables = HashMap::new(); loop { if current_plan.inputs().is_empty() { if let LogicalPlan::TableScan(ref t) = current_plan { // Use TableScan to get the filepath and/or size let filepath = get_filepath(¤t_plan); let size = get_table_size(¤t_plan); match filepath { Some(f) => { // TODO: Add better handling for when a table is read in more than once // https://github.com/dask-contrib/dask-sql/issues/1121 if tables.contains_key(&t.table_name.to_string()) { return HashMap::new(); } tables.insert( t.table_name.to_string(), TableInfo { table_name: t.table_name.to_string(), filepath: f.clone(), size, filters: t.filters.clone(), }, ); break; } None => return HashMap::new(), } } break; } else if current_plan.inputs().len() > 1 { match current_plan { LogicalPlan::Join(ref j) => { // Recurse on left and right inputs of Join let (left_tables, right_tables) = (gather_tables(&j.left), gather_tables(&j.right)); if check_table_overlaps(&tables, &left_tables, &right_tables) { return HashMap::new(); } // Add left_tables and right_tables to HashMap tables.extend(left_tables); tables.extend(right_tables); } LogicalPlan::CrossJoin(ref c) => { // Recurse on left and right inputs of CrossJoin let (left_tables, right_tables) = (gather_tables(&c.left), gather_tables(&c.right)); if check_table_overlaps(&tables, &left_tables, &right_tables) { return HashMap::new(); } // Add left_tables and right_tables to HashMap tables.extend(left_tables); tables.extend(right_tables); } LogicalPlan::Union(ref u) => { // Recurse on inputs vector of Union for input in &u.inputs { let union_tables = gather_tables(input); // TODO: Add better handling for when a table is read in more than once // https://github.com/dask-contrib/dask-sql/issues/1121 if tables.keys().any(|k| union_tables.contains_key(k)) || union_tables.keys().any(|k| tables.contains_key(k)) { return HashMap::new(); } // 
Add union_tables to HashMap tables.extend(union_tables); } } _ => { warn!("Skipping optimizer rule 'DynamicPartitionPruning'"); return HashMap::new(); } } break; } else { // Move on to next step current_plan = current_plan.inputs()[0].clone(); } } tables } // TODO: Add better handling for when a table is read in more than once // https://github.com/dask-contrib/dask-sql/issues/1121 fn check_table_overlaps( m1: &HashMap, m2: &HashMap, m3: &HashMap, ) -> bool { m1.keys().any(|k| m2.contains_key(k)) || m2.keys().any(|k| m1.contains_key(k)) || m1.keys().any(|k| m3.contains_key(k)) || m3.keys().any(|k| m1.contains_key(k)) || m2.keys().any(|k| m3.contains_key(k)) || m3.keys().any(|k| m2.contains_key(k)) } fn get_filepath(plan: &LogicalPlan) -> Option<&String> { match plan { LogicalPlan::TableScan(scan) => scan .source .as_any() .downcast_ref::()? .filepath(), _ => None, } } fn get_table_size(plan: &LogicalPlan) -> Option { match plan { LogicalPlan::TableScan(scan) => scan .source .as_any() .downcast_ref::()? 
.statistics() .map(|stats| stats.get_row_count() as usize), _ => None, } } // This function parses through the LogicalPlan, grabs any aliases, and adds them to a HashMap // where the key is the alias name and the value is the table name fn gather_aliases(plan: &LogicalPlan) -> HashMap { let mut current_plan = plan.clone(); let mut aliases = HashMap::new(); loop { if current_plan.inputs().is_empty() { break; } else if current_plan.inputs().len() > 1 { match current_plan { LogicalPlan::Join(ref j) => { // Recurse on left and right inputs of Join let (left_aliases, right_aliases) = (gather_aliases(&j.left), gather_aliases(&j.right)); // Add left_aliases and right_aliases to HashMap aliases.extend(left_aliases); aliases.extend(right_aliases); } LogicalPlan::CrossJoin(ref c) => { // Recurse on left and right inputs of CrossJoin let (left_aliases, right_aliases) = (gather_aliases(&c.left), gather_aliases(&c.right)); // Add left_aliases and right_aliases to HashMap aliases.extend(left_aliases); aliases.extend(right_aliases); } LogicalPlan::Union(ref u) => { // Recurse on inputs vector of Union for input in &u.inputs { let union_aliases = gather_aliases(input); // Add union_aliases to HashMap aliases.extend(union_aliases); } } _ => { return HashMap::new(); } } break; } else { if let LogicalPlan::SubqueryAlias(ref s) = current_plan { match *s.input { LogicalPlan::TableScan(ref t) => { aliases.insert(s.alias.to_string(), t.table_name.to_string().clone()); } // Sometimes a TableScan is immediately followed by a Projection, so we can // still use the alias for the table LogicalPlan::Projection(ref p) => { if let LogicalPlan::TableScan(ref t) = *p.input { aliases.insert(s.alias.to_string(), t.table_name.to_string().clone()); } } _ => (), } } // Move on to next step current_plan = current_plan.inputs()[0].clone(); } } aliases } // Wrapper for floats, since they are not hashable #[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] struct FloatWrapper(f64); impl Eq for 
FloatWrapper {} impl Hash for FloatWrapper { fn hash(&self, state: &mut H) { // Convert the f64 to a u64 using transmute let bits: u64 = self.0.to_bits(); // Use the u64's hash implementation bits.hash(state); } } // Wrapper for possible row value types #[derive(Clone, Debug, Eq, Hash, PartialEq)] enum RowValue { String(Option), Int64(Option), Int32(Option), Double(Option), } // This function uses the table name, column name, and filters to read in the relevant columns, // filter out row values, and construct a HashSet of relevant row values for the specified column, // i.e., the column involved in the join fn read_table( table_string: String, field_string: String, tables: HashMap, ) -> Option> { let file_path = tables.get(&table_string).unwrap().filepath.clone(); let paths: fs::ReadDir; let mut files = vec![]; if fs::metadata(&file_path) .map(|metadata| metadata.is_dir()) .unwrap_or(false) { // Obtain filepaths to all relevant Parquet files, e.g., in a directory of Parquet files paths = fs::read_dir(&file_path).unwrap(); for path in paths { files.push(path.unwrap().path().display().to_string()) } } else { // Obtain single Parquet file files.push(file_path); } // Using the filepaths to the Parquet tables, obtain the schemas of the relevant tables let schema: &Type = &SerializedFileReader::try_from(files[0].clone()) .unwrap() .metadata() .file_metadata() .schema() .clone(); // Use the schemas of the relevant tables to obtain the physical type of the relevant columns let physical_type = get_physical_type(schema, field_string.clone()); // A TableScan may include existing filters. These conditions should be used to filter the data // after being read. 
Therefore, the columns involved in these filters should be read in as well let filters = tables.get(&table_string).unwrap().filters.clone(); let filtered_fields = get_filtered_fields(&filters, schema, field_string.clone()); let filtered_string = filtered_fields.0; let filtered_types = filtered_fields.1; let filtered_names = filtered_fields.2; if filters.len() != filtered_names.len() { warn!("Unable to check existing filters for optimizer rule 'DynamicPartitionPruning'"); return None; } // Specify which columns to include in the reader, then read in the rows let repetition = get_repetition(schema, field_string.clone()); let physical_type = physical_type.unwrap().to_string(); let projection_schema = "message schema { ".to_owned() + &filtered_string + &repetition.unwrap() + " " + &physical_type + " " + &field_string + "; }"; let projection = parse_message_type(&projection_schema).ok(); let mut rows = Vec::new(); for file in files { let reader_result = SerializedFileReader::try_from(&*file.clone()); if let Ok(reader) = reader_result { let row_iter_result = RowIter::from_file_into(Box::new(reader)) .project(projection.clone()) .ok(); if let Some(row_iter) = row_iter_result { rows.extend(row_iter.map(|r| r.expect("Parquet error encountered"))); } else { // TODO: Investigate cases when this would happen rows.clear(); break; } } else { rows.clear(); break; } } if rows.is_empty() { return None; } // Create HashSets for the join column values let mut value_set: HashSet = HashSet::new(); for row in rows { // Since a TableScan may have its own filters, we want to ensure that the values in // value_set satisfy the TableScan filters let mut satisfies_filters = true; let mut row_index = 0; for index in 0..filters.len() { if filtered_names[index] != field_string { let current_type = &filtered_types[index]; match current_type.as_str() { "BYTE_ARRAY" => { let string_value = row.get_string(row_index).ok(); if !satisfies_string(string_value, filters[index].clone()) { satisfies_filters 
= false; } } "INT64" => { let long_value = row.get_long(row_index).ok(); if !satisfies_int64(long_value, filters[index].clone()) { satisfies_filters = false; } } "INT32" => { let int_value = row.get_int(row_index).ok(); if !satisfies_int32(int_value, filters[index].clone()) { satisfies_filters = false; } } "DOUBLE" => { let double_value = row.get_double(row_index).ok(); if !satisfies_float(double_value, filters[index].clone()) { satisfies_filters = false; } } u => panic!("Unknown PhysicalType {u}"), } row_index += 1; } } // After verifying that the row satisfies all existing filters, we add the column value to // the HashSet if satisfies_filters { match physical_type.as_str() { "BYTE_ARRAY" => { let r = row.get_string(row_index).ok(); value_set.insert(RowValue::String(r.cloned())); } "INT64" => { let r = row.get_long(row_index).ok(); value_set.insert(RowValue::Int64(r)); } "INT32" => { let r = row.get_int(row_index).ok(); value_set.insert(RowValue::Int32(r)); } "DOUBLE" => { let r = row.get_double(row_index).ok(); if let Some(f) = r { value_set.insert(RowValue::Double(Some(FloatWrapper(f)))); } else { value_set.insert(RowValue::Double(None)); } } _ => panic!("Unknown PhysicalType"), } } } Some(value_set) } // A column has a physical_type (INT64, etc.) that needs to be included when specifying which // columns to read in. To get the physical_type, we grab it from the schema fn get_physical_type(schema: &Type, field: String) -> Option { match schema { Type::GroupType { basic_info: _, fields, } => { for f in fields { let match_field = &*f.clone(); match match_field { Type::PrimitiveType { basic_info, physical_type, .. } => { if basic_info.name() == field { return Some(*physical_type); } } _ => return None, } } None } _ => None, } } // A column has a repetition (i.e., REQUIRED or OPTIONAL) that needs to be included when specifying // which columns to read in. 
To get the repetition, we grab it from the schema fn get_repetition(schema: &Type, field: String) -> Option { match schema { Type::GroupType { basic_info: _, fields, } => { for f in fields { let match_field = &*f.clone(); match match_field { Type::PrimitiveType { basic_info, .. } => { if basic_info.name() == field { return Some(basic_info.repetition().to_string()); } } _ => return None, } } None } _ => None, } } // This is a helper function to deal with TableScan filters for reading in the data. The first // value returned is a string representation of the projection used to read in the relevant // columns. The second value returned is a vector of the physical_type of each column that has has // a filter, in the order that they are being read. The third value returned is a vector of the // column names, in the order that they are being read. fn get_filtered_fields( filters: &Vec, schema: &Type, field: String, ) -> (String, Vec, Vec) { // Used to create a string representation of the projection // for the TableScan filters to be read let mut filtered_fields = vec![]; // All physical types involved in TableScan filters let mut filtered_types = vec![]; // All columns involved in TableScan filters let mut filtered_columns = vec![]; for filter in filters { match filter { Expr::BinaryExpr(b) => { if let Expr::Column(column) = &*b.left { push_filtered_fields( column, schema, field.clone(), &mut filtered_fields, &mut filtered_columns, &mut filtered_types, ); } } Expr::IsNotNull(e) => { if let Expr::Column(column) = &**e { push_filtered_fields( column, schema, field.clone(), &mut filtered_fields, &mut filtered_columns, &mut filtered_types, ); } } _ => (), } } (filtered_fields.join(""), filtered_types, filtered_columns) } // Helper function for get_filtered_fields fn push_filtered_fields( column: &Column, schema: &Type, field: String, filtered_fields: &mut Vec, filtered_columns: &mut Vec, filtered_types: &mut Vec, ) { let current_field = column.name.clone(); let 
physical_type = get_physical_type(schema, current_field.clone()) .unwrap() .to_string(); if current_field != field { let repetition = get_repetition(schema, current_field.clone()); filtered_fields.push(repetition.unwrap()); filtered_fields.push(" ".to_string()); filtered_fields.push(physical_type.clone()); filtered_fields.push(" ".to_string()); filtered_fields.push(current_field.clone()); filtered_fields.push("; ".to_string()); } filtered_types.push(physical_type); filtered_columns.push(current_field); } // Returns a boolean representing whether a string satisfies a given filter fn satisfies_string(string_value: Option<&String>, filter: Expr) -> bool { match filter { Expr::BinaryExpr(b) => match b.op { Operator::Eq => Expr::Literal(ScalarValue::Utf8(string_value.cloned())) == *b.right, Operator::NotEq => Expr::Literal(ScalarValue::Utf8(string_value.cloned())) != *b.right, _ => { panic!("Unknown satisfies_string operator"); } }, Expr::IsNotNull(_) => string_value.is_some(), _ => { panic!("Unknown satisfies_string Expr"); } } } // Returns a boolean representing whether an Int64 satisfies a given filter fn satisfies_int64(long_value: Option, filter: Expr) -> bool { match filter { Expr::BinaryExpr(b) => { let filter_value = *b.right; let int_value: i64 = match filter_value { Expr::Literal(ScalarValue::Int64(i)) => i.unwrap(), Expr::Literal(ScalarValue::Int32(i)) => i64::from(i.unwrap()), Expr::Literal(ScalarValue::Float64(i)) => i.unwrap() as i64, Expr::Literal(ScalarValue::TimestampNanosecond(i, None)) => i.unwrap(), Expr::Literal(ScalarValue::Date32(i)) => i64::from(i.unwrap()), // TODO: Add logic to check if the string can be converted to a timestamp Expr::Literal(ScalarValue::Utf8(_)) => return false, _ => { panic!("Unknown ScalarValue type {filter_value}"); } }; let filter_value = Expr::Literal(ScalarValue::Int64(Some(int_value))); match b.op { Operator::Eq => Expr::Literal(ScalarValue::Int64(long_value)) == filter_value, Operator::NotEq => 
Expr::Literal(ScalarValue::Int64(long_value)) != filter_value, Operator::Gt => Expr::Literal(ScalarValue::Int64(long_value)) > filter_value, Operator::Lt => Expr::Literal(ScalarValue::Int64(long_value)) < filter_value, Operator::GtEq => Expr::Literal(ScalarValue::Int64(long_value)) >= filter_value, Operator::LtEq => Expr::Literal(ScalarValue::Int64(long_value)) <= filter_value, _ => { panic!("Unknown satisfies_int64 operator"); } } } Expr::IsNotNull(_) => long_value.is_some(), _ => { panic!("Unknown satisfies_int64 Expr"); } } } // Returns a boolean representing whether an Int32 satisfies a given filter fn satisfies_int32(long_value: Option, filter: Expr) -> bool { match filter { Expr::BinaryExpr(b) => { let filter_value = *b.right; let int_value: i32 = match filter_value { Expr::Literal(ScalarValue::Int64(i)) => i.unwrap() as i32, Expr::Literal(ScalarValue::Int32(i)) => i.unwrap(), Expr::Literal(ScalarValue::Float64(i)) => i.unwrap() as i32, _ => { panic!("Unknown ScalarValue type {filter_value}"); } }; let filter_value = Expr::Literal(ScalarValue::Int32(Some(int_value))); match b.op { Operator::Eq => Expr::Literal(ScalarValue::Int32(long_value)) == filter_value, Operator::NotEq => Expr::Literal(ScalarValue::Int32(long_value)) != filter_value, Operator::Gt => Expr::Literal(ScalarValue::Int32(long_value)) > filter_value, Operator::Lt => Expr::Literal(ScalarValue::Int32(long_value)) < filter_value, Operator::GtEq => Expr::Literal(ScalarValue::Int32(long_value)) >= filter_value, Operator::LtEq => Expr::Literal(ScalarValue::Int32(long_value)) <= filter_value, _ => { panic!("Unknown satisfies_int32 operator"); } } } Expr::IsNotNull(_) => long_value.is_some(), _ => { panic!("Unknown satisfies_int32 Expr"); } } } // Returns a boolean representing whether an Float64 satisfies a given filter fn satisfies_float(long_value: Option, filter: Expr) -> bool { match filter { Expr::BinaryExpr(b) => { let filter_value = *b.right; let float_value: f64 = match filter_value { 
Expr::Literal(ScalarValue::Int64(i)) => i.unwrap() as f64, Expr::Literal(ScalarValue::Int32(i)) => i.unwrap() as f64, Expr::Literal(ScalarValue::Float64(i)) => i.unwrap(), _ => { panic!("Unknown ScalarValue type {filter_value}"); } }; let filter_value = Expr::Literal(ScalarValue::Float64(Some(float_value))); match b.op { Operator::Eq => Expr::Literal(ScalarValue::Float64(long_value)) == filter_value, Operator::NotEq => Expr::Literal(ScalarValue::Float64(long_value)) != filter_value, Operator::Gt => Expr::Literal(ScalarValue::Float64(long_value)) > filter_value, Operator::Lt => Expr::Literal(ScalarValue::Float64(long_value)) < filter_value, Operator::GtEq => Expr::Literal(ScalarValue::Float64(long_value)) >= filter_value, Operator::LtEq => Expr::Literal(ScalarValue::Float64(long_value)) <= filter_value, _ => { panic!("Unknown satisfies_float operator"); } } } Expr::IsNotNull(_) => long_value.is_some(), _ => { panic!("Unknown satisfies_float Expr"); } } } // Used to simplify the signature of combine_sets type RowHashSet = HashSet; type RowOptionHashSet = Option; type RowTuple = (RowOptionHashSet, RowOptionHashSet); type RowVec = Vec; // Given a vector of hashsets to be set as TableScan filters, a vector of tuples representing the // tables involved in a join, a vector of tuples representing the columns involved in a join, and // a hashset of fact tables in the query; return a hashmap where the key is a tuple of the table // and column names, and the value is the hashset representing the INLIST filter specified in the // TableScan. 
fn combine_sets( join_values: RowVec, join_tables: Vec<(String, String)>, join_fields: Vec<(String, String)>, fact_tables: HashSet, ) -> HashMap<(String, String), HashSet> { let mut sets: HashMap<(String, String), HashSet> = HashMap::new(); for i in 0..join_values.len() { // Case when we were able to read in both tables involved in the join if let (Some(set1), Some(set2)) = (&join_values[i].0, &join_values[i].1) { // The INLIST vector will be the intersection of both hashsets let set_intersection = set1.intersection(set2); let mut values = HashSet::new(); for value in set_intersection { values.insert(value.clone()); } let current_table = join_tables[i].0.clone(); // We only create INLIST filters for fact tables if fact_tables.contains(¤t_table) { let current_field = join_fields[i].0.clone(); add_to_existing_set(&mut sets, values.clone(), current_table, current_field); } let current_table = join_tables[i].1.clone(); // We only create INLIST filters for fact tables if fact_tables.contains(¤t_table) { let current_field = join_fields[i].1.clone(); add_to_existing_set(&mut sets, values.clone(), current_table, current_field); } // Case when we were only able to read in the left table of the join } else if let Some(values) = &join_values[i].0 { let current_table = join_tables[i].0.clone(); // We only create INLIST filters for fact tables if fact_tables.contains(¤t_table) { let current_field = join_fields[i].0.clone(); add_to_existing_set(&mut sets, values.clone(), current_table, current_field); } let current_table = join_tables[i].1.clone(); // We only create INLIST filters for fact tables if fact_tables.contains(¤t_table) { let current_field = join_fields[i].1.clone(); add_to_existing_set(&mut sets, values.clone(), current_table, current_field); } // Case when we were only able to read in the right table of the join } else if let Some(values) = &join_values[i].1 { let current_table = join_tables[i].0.clone(); // We only create INLIST filters for fact tables if 
fact_tables.contains(¤t_table) { let current_field = join_fields[i].0.clone(); add_to_existing_set(&mut sets, values.clone(), current_table, current_field); } let current_table = join_tables[i].1.clone(); // We only create INLIST filters for fact tables if fact_tables.contains(¤t_table) { let current_field = join_fields[i].1.clone(); add_to_existing_set(&mut sets, values.clone(), current_table, current_field); } } } sets } // Given a mutable hashmap (the hashmap which will eventually be returned by the `combine_sets` // function), a hashset of values, a table name, and a column name; insert the hashset of values // into the hashmap, where the key is a tuple of the table and column names. fn add_to_existing_set( sets: &mut HashMap<(String, String), HashSet>, values: HashSet, current_table: String, current_field: String, ) { let existing_set = sets.get(&(current_table.clone(), current_field.clone())); match existing_set { // If the tuple for (current_table, current_field) already exists, then we want to combine // the existing set with the new hashset being inserted; to do this, we take the // intersection of both sets. Some(s) => { let s = s.clone(); let v = values.iter().cloned().collect::>(); let s = s.intersection(&v); let mut set_intersection = HashSet::new(); for i in s { set_intersection.insert(i.clone()); } sets.insert((current_table, current_field), set_intersection.clone()); } // If the tuple for (current_table, current_field) does not already exist as a key in the // hashmap, then simply create it and set the hashset as the value None => { sets.insert((current_table, current_field), values); } } } // Given a LogicalPlan and a hashmap where the key is a tuple containing a table name and column // and the value is a hashset of unique row values, parse the LogicalPlan and insert INLIST filters // at the TableScan level. 
fn optimize_table_scans( plan: &LogicalPlan, filter_values: HashMap<(String, String), HashSet>, ) -> Result> { // Replaces existing TableScan with a new TableScan which includes // the new binary expression filter created from reading in the join columns match plan { LogicalPlan::TableScan(t) => { let table_name = t.table_name.to_string(); let table_filters: HashMap<(String, String), HashSet> = filter_values .iter() .filter(|(key, _value)| key.0 == table_name) .map(|(key, value)| ((key.0.to_owned(), key.1.to_owned()), value.clone())) .collect(); let mut updated_filters = t.filters.clone(); for (key, value) in table_filters.iter() { let current_expr = format_inlist_expr(value.clone(), key.0.to_owned(), key.1.to_owned()); if let Some(e) = current_expr { updated_filters.push(e); } } let scan = LogicalPlan::TableScan(TableScan { table_name: t.table_name.clone(), source: t.source.clone(), projection: t.projection.clone(), projected_schema: t.projected_schema.clone(), filters: updated_filters, fetch: t.fetch, }); Ok(Some(scan)) } _ => optimize_children(plan, filter_values), } } // Given a hashset of values, a table name, and a column name, return a DataFusion INLIST Expr fn format_inlist_expr( value_set: HashSet, join_table: String, join_field: String, ) -> Option { let expr = Box::new(Expr::Column(Column::new(Some(join_table), join_field))); let mut list: Vec = vec![]; // Need to correctly format the ScalarValue type for value in value_set { if let RowValue::String(s) = value { if s.is_some() { let v = Expr::Literal(ScalarValue::Utf8(s)); list.push(v); } } else if let RowValue::Int64(l) = value { if l.is_some() { let v = Expr::Literal(ScalarValue::Int64(l)); list.push(v); } } else if let RowValue::Int32(i) = value { if i.is_some() { let v = Expr::Literal(ScalarValue::Int32(i)); list.push(v); } } else if let RowValue::Double(Some(f)) = value { let v = Expr::Literal(ScalarValue::Float64(Some(f.0))); list.push(v); } } if list.is_empty() { None } else { 
Some(Expr::InList(InList { expr, list, negated: false, })) } } // Given a LogicalPlan and the same hashmap as the `optimize_table_scans` function, correctly // iterate through the LogicalPlan nodes. Similar to DataFusion's `optimize_children` function, but // recurses on the `optimize_table_scans` function instead. fn optimize_children( plan: &LogicalPlan, filter_values: HashMap<(String, String), HashSet>, ) -> Result> { let new_exprs = plan.expressions(); let mut new_inputs = Vec::with_capacity(plan.inputs().len()); let mut plan_is_changed = false; for input in plan.inputs() { let new_input = optimize_table_scans(input, filter_values.clone())?; plan_is_changed = plan_is_changed || new_input.is_some(); new_inputs.push(new_input.unwrap_or_else(|| input.clone())) } if plan_is_changed { Ok(Some(plan.with_new_exprs(new_exprs, &new_inputs)?)) } else { Ok(None) } } ================================================ FILE: src/sql/optimizer/join_reorder.rs ================================================ //! Join reordering based on the paper "Improving Join Reordering for Large Scale Distributed Computing" //! 
https://ieeexplore.ieee.org/document/9378281 use std::collections::HashSet; use datafusion_python::{ datafusion_common::{Column, Result}, datafusion_expr::{Expr, Join, JoinType, LogicalPlan, LogicalPlanBuilder}, datafusion_optimizer::{utils, utils::split_conjunction, OptimizerConfig, OptimizerRule}, }; use log::warn; use crate::sql::table::DaskTableSource; pub struct JoinReorder { /// Ratio of the size of the dimension tables to fact tables fact_dimension_ratio: f64, /// Maximum number of fact tables to allow in a join max_fact_tables: usize, /// Whether to preserve user-defined order of unfiltered dimensions preserve_user_order: bool, /// Constant to use when determining the number of rows produced by a /// filtered relation filter_selectivity: f64, } impl JoinReorder { pub fn new( fact_dimension_ratio: Option, max_fact_tables: Option, preserve_user_order: Option, filter_selectivity: Option, ) -> Self { Self { // FIXME: Default value for fact_dimension_ratio should be 0.3, not 0.7 fact_dimension_ratio: fact_dimension_ratio.unwrap_or(0.7), max_fact_tables: max_fact_tables.unwrap_or(2), preserve_user_order: preserve_user_order.unwrap_or(true), filter_selectivity: filter_selectivity.unwrap_or(1.0), } } } impl OptimizerRule for JoinReorder { fn name(&self) -> &str { "join_reorder" } fn try_optimize( &self, plan: &LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { let original_plan = plan.clone(); // Recurse down first // We want the equivalent of Spark's transformUp here let plan = utils::optimize_children(self, plan, _config)?; match &plan { Some(LogicalPlan::Join(join)) if join.join_type == JoinType::Inner => { optimize_join(self, plan.as_ref().unwrap(), join) } Some(plan) => Ok(Some(plan.clone())), None => match &original_plan { LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { optimize_join(self, &original_plan, join) } _ => Ok(None), }, } } } fn optimize_join( rule: &JoinReorder, plan: &LogicalPlan, join: &Join, ) -> Result> { // FIXME: 
Check fact/fact join logic if !is_supported_join(join) { return Ok(Some(plan.clone())); } // Extract the relations and join conditions let (rels, conds) = extract_inner_joins(plan); let mut join_conds = HashSet::new(); for cond in &conds { match cond { (Expr::Column(l), Expr::Column(r)) => { join_conds.insert((l.clone(), r.clone())); } _ => { return Ok(Some(plan.clone())); } } } // Split rels into facts and dims let largest_rel_size = rels.iter().map(|rel| rel.size).max().unwrap() as f64; // Vectors for the fact and dimension tables, respectively let mut facts = vec![]; let mut dims = vec![]; for rel in &rels { // If the ratio is larger than the fact_dimension_ratio, it is a fact table // Else, it is a dimension table if rel.size as f64 / largest_rel_size > rule.fact_dimension_ratio { facts.push(rel.clone()); } else { dims.push(rel.clone()); } } if facts.is_empty() || dims.is_empty() { return Ok(Some(plan.clone())); } if facts.len() > rule.max_fact_tables { return Ok(Some(plan.clone())); } // Get list of dimension tables without a selective predicate let mut unfiltered_dimensions = get_unfiltered_dimensions(&dims); if !rule.preserve_user_order { unfiltered_dimensions.sort_by(|a, b| a.size.cmp(&b.size)); } // Get list of dimension tables with a selective predicate and sort it let filtered_dimensions = get_filtered_dimensions(&dims); let mut filtered_dimensions: Vec = filtered_dimensions .iter() .map(|rel| Relation { plan: rel.plan.clone(), size: (rel.size as f64 * rule.filter_selectivity) as usize, }) .collect(); filtered_dimensions.sort_by(|a, b| a.size.cmp(&b.size)); // Merge both the lists of dimensions by giving user order // the preference for tables without a selective predicate, // whereas for tables with selective predicates giving preference // to smaller tables. When comparing the top of both // the lists, if size of the top table in the selective predicate // list is smaller than top of the other list, choose it otherwise // vice-versa. 
// This algorithm is a greedy approach where smaller // joins with filtered dimension table are preferred for execution // earlier than other Joins to improve Join performance. We try to keep // the user order intact when unsure about reordering to make sure // regressions are minimized. let mut result = vec![]; while !filtered_dimensions.is_empty() || !unfiltered_dimensions.is_empty() { if !filtered_dimensions.is_empty() { if !unfiltered_dimensions.is_empty() { if filtered_dimensions[0].size < unfiltered_dimensions[0].size { result.push(filtered_dimensions.remove(0)); } else { result.push(unfiltered_dimensions.remove(0)); } } else { result.push(filtered_dimensions.remove(0)); } } else { result.push(unfiltered_dimensions.remove(0)); } } let dim_plans: Vec = result.iter().map(|rel| rel.plan.clone()).collect(); let optimized = if facts.len() == 1 { build_join_tree(&facts[0].plan, &dim_plans, &mut join_conds)? } else { // Build one join tree for each fact table let fact_dim_joins = facts .iter() .map(|f| build_join_tree(&f.plan, &dim_plans, &mut join_conds)) .collect::>>()?; // Join the trees together build_join_tree(&fact_dim_joins[0], &fact_dim_joins[1..], &mut join_conds)? 
}; if join_conds.is_empty() { Ok(Some(optimized)) } else { Ok(Some(plan.clone())) } } /// Represents a Fact or Dimension table, possibly nested in a filter #[derive(Clone, Debug)] struct Relation { /// Plan containing the table scan for the fact or dimension table /// May also contain Filter and SubqueryAlias plan: LogicalPlan, /// Estimated size of the underlying table before any filtering is applied size: usize, } impl Relation { fn new(plan: LogicalPlan) -> Self { let size = get_table_size(&plan); match size { Some(s) => Self { plan, size: s }, None => { warn!("Table statistics couldn't be obtained; assuming 100 rows"); Self { plan, size: 100 } } } } /// Determine if this plan contains any filters fn has_filter(&self) -> bool { has_filter(&self.plan) } } fn has_filter(plan: &LogicalPlan) -> bool { /// We want to ignore "IsNotNull" filters that are added for join keys since they exist /// for most dimension tables fn is_real_filter(predicate: &Expr) -> bool { let exprs = split_conjunction(predicate); let x = exprs .iter() .filter(|e| !matches!(e, Expr::IsNotNull(_))) .count(); x > 0 } match plan { LogicalPlan::Filter(filter) => is_real_filter(&filter.predicate), LogicalPlan::TableScan(scan) => scan.filters.iter().any(is_real_filter), _ => plan.inputs().iter().any(|child| has_filter(child)), } } /// Simple Join Constraint: Only INNER Joins are considered /// which can be composed of other Joins too. But apart /// from the Joins, none of the operator in both the left and /// right side of the join should be non-deterministic, or have /// output greater than the input to the operator. For instance, /// Filter would be allowed operator as it reduces the output /// over input, but a project adding extra column will not /// be allowed. 
It is difficult to reason about operators that /// add extra to output when dealing with just table sizes, so /// instead we only allowed operators from selected set of /// operators fn is_supported_join(join: &Join) -> bool { // FIXME: Check for deterministic filter expressions fn is_supported_rel(plan: &LogicalPlan) -> bool { match plan { LogicalPlan::Join(join) => { join.join_type == JoinType::Inner // FIXME: Need to support join filters correctly && join.filter.is_none() && is_supported_rel(&join.left) && is_supported_rel(&join.right) } LogicalPlan::Filter(filter) => is_supported_rel(&filter.input), LogicalPlan::SubqueryAlias(sq) => is_supported_rel(&sq.input), LogicalPlan::TableScan(_) => true, _ => false, } } is_supported_rel(&LogicalPlan::Join(join.clone())) } /// Extracts items of consecutive inner joins and join conditions /// This method works for bushy trees and left/right deep trees fn extract_inner_joins(plan: &LogicalPlan) -> (Vec, HashSet<(Expr, Expr)>) { fn _extract_inner_joins( plan: &LogicalPlan, rels: &mut Vec, conds: &mut HashSet<(Expr, Expr)>, ) { match plan { LogicalPlan::Join(join) if join.join_type == JoinType::Inner && join.filter.is_none() => { _extract_inner_joins(&join.left, rels, conds); _extract_inner_joins(&join.right, rels, conds); for (l, r) in &join.on { conds.insert((l.clone(), r.clone())); } } /* FIXME: Need to support join filters correctly LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { _extract_inner_joins(&join.left, rels, conds); _extract_inner_joins(&join.right, rels, conds); for (l, r) in &join.on { conds.insert((l.clone(), r.clone())); } // Need to save this info somewhere let join_filter = join.filter.as_ref().unwrap(); } */ _ => { if find_join(plan).is_some() { for x in plan.inputs() { _extract_inner_joins(x, rels, conds); } } else { // Leaf node rels.push(plan.clone()) } } } } let mut rels = vec![]; let mut conds = HashSet::new(); _extract_inner_joins(plan, &mut rels, &mut conds); let rels = 
rels.into_iter().map(Relation::new).collect(); (rels, conds) } /// Find first (top-level) join in plan fn find_join(plan: &LogicalPlan) -> Option { match plan { LogicalPlan::Join(join) => Some(join.clone()), other => { if other.inputs().is_empty() { None } else { for input in &other.inputs() { if let Some(join) = find_join(input) { return Some(join); } } None } } } } fn get_unfiltered_dimensions(dims: &[Relation]) -> Vec { dims.iter().filter(|t| !t.has_filter()).cloned().collect() } fn get_filtered_dimensions(dims: &[Relation]) -> Vec { dims.iter().filter(|t| t.has_filter()).cloned().collect() } fn build_join_tree( fact: &LogicalPlan, dims: &[LogicalPlan], conds: &mut HashSet<(Column, Column)>, ) -> Result { let mut b = LogicalPlanBuilder::from(fact.clone()); for dim in dims { // Find join keys between the fact and this dim let mut join_keys = vec![]; for (l, r) in conds.iter() { if (b.schema().index_of_column(l).is_ok() && dim.schema().index_of_column(r).is_ok()) || b.schema().index_of_column(r).is_ok() && dim.schema().index_of_column(l).is_ok() { join_keys.push((l.clone(), r.clone())); } } if !join_keys.is_empty() { let left_keys: Vec = join_keys.iter().map(|(l, _r)| l.clone()).collect(); let right_keys: Vec = join_keys.iter().map(|(_l, r)| r.clone()).collect(); for key in join_keys { conds.remove(&key); } /* FIXME: Build join with join_keys when needed self.join( right: LogicalPlan, join_type: JoinType, join_keys: (Vec>, Vec>), filter: Option, ) */ b = b.join(dim.clone(), JoinType::Inner, (left_keys, right_keys), None)?; } } b.build() } fn get_table_size(plan: &LogicalPlan) -> Option { match plan { LogicalPlan::TableScan(scan) => scan .source .as_any() .downcast_ref::() .expect("should be a DaskTableSource") .statistics() .map(|stats| stats.get_row_count() as usize), _ => get_table_size(plan.inputs()[0]), } } ================================================ FILE: src/sql/optimizer/utils.rs ================================================ // Licensed to the 
Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. //! Collection of utility functions that are leveraged by the query optimizer rules use std::{ collections::{BTreeSet, HashMap}, sync::Arc, }; use datafusion_python::{ datafusion_common::{Column, DFSchema, DFSchemaRef, Result}, datafusion_expr::{ and, expr::{Alias, BinaryExpr}, expr_rewriter::{replace_col, strip_outer_reference}, logical_plan::{Filter, LogicalPlan}, Expr, LogicalPlanBuilder, Operator, }, datafusion_optimizer::optimizer::{OptimizerConfig, OptimizerRule}, }; use log::{debug, trace}; #[allow(dead_code)] /// Convenience rule for writing optimizers: recursively invoke /// optimize on plan's children and then return a node of the same /// type. Useful for optimizer rules which want to leave the type /// of plan unchanged but still apply to the children. /// This also handles the case when the `plan` is a [`LogicalPlan::Explain`]. /// /// Returning `Ok(None)` indicates that the plan can't be optimized by the `optimizer`. 
pub fn optimize_children( optimizer: &impl OptimizerRule, plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { let mut new_inputs = Vec::with_capacity(plan.inputs().len()); let mut plan_is_changed = false; for input in plan.inputs() { let new_input = optimizer.try_optimize(input, config)?; plan_is_changed = plan_is_changed || new_input.is_some(); new_inputs.push(new_input.unwrap_or_else(|| input.clone())) } if plan_is_changed { Ok(Some(plan.with_new_inputs(&new_inputs)?)) } else { Ok(None) } } /// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// See [`split_conjunction_owned`] for more details and an example. pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { split_conjunction_impl(expr, vec![]) } fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { match expr { Expr::BinaryExpr(BinaryExpr { right, op: Operator::And, left, }) => { let exprs = split_conjunction_impl(left, exprs); split_conjunction_impl(right, exprs) } Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs), other => { exprs.push(other); exprs } } } /// Extract join predicates from the correclated subquery. /// The join predicate means that the expression references columns /// from both the subquery and outer table or only from the outer table. /// /// Returns join predicates and subquery(extracted). pub(crate) fn extract_join_filters(maybe_filter: &LogicalPlan) -> Result<(Vec, LogicalPlan)> { if let LogicalPlan::Filter(plan_filter) = maybe_filter { let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); let (join_filters, subquery_filters) = find_join_exprs(subquery_filter_exprs)?; // if the subquery still has filter expressions, restore them. let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone()); if let Some(expr) = conjunction(subquery_filters) { plan = plan.filter(expr)? 
} Ok((join_filters, plan.build()?)) } else { Ok((vec![], maybe_filter.clone())) } } #[allow(dead_code)] /// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// This is often used to "split" filter expressions such as `col1 = 5 /// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; /// /// # Example /// ``` /// # use datafusion_python::datafusion_expr::{col, lit}; /// # use datafusion_python::datafusion_optimizer::utils::split_conjunction_owned; /// // a=1 AND b=2 /// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); /// /// // [a=1, b=2] /// let split = vec![ /// col("a").eq(lit(1)), /// col("b").eq(lit(2)), /// ]; /// /// // use split_conjunction_owned to split them /// assert_eq!(split_conjunction_owned(expr), split); /// ``` pub fn split_conjunction_owned(expr: Expr) -> Vec { split_binary_owned(expr, Operator::And) } #[allow(dead_code)] /// Splits an owned binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` /// /// This is often used to "split" expressions such as `col1 = 5 /// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; /// /// # Example /// ``` /// # use datafusion_python::datafusion_expr::{col, lit, Operator}; /// # use datafusion_python::datafusion_optimizer::utils::split_binary_owned; /// # use std::ops::Add; /// // a=1 + b=2 /// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2))); /// /// // [a=1, b=2] /// let split = vec![ /// col("a").eq(lit(1)), /// col("b").eq(lit(2)), /// ]; /// /// // use split_binary_owned to split them /// assert_eq!(split_binary_owned(expr, Operator::Plus), split); /// ``` pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { split_binary_owned_impl(expr, op, vec![]) } #[allow(dead_code)] fn split_binary_owned_impl(expr: Expr, operator: Operator, mut exprs: Vec) -> Vec { match expr { Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { let exprs = split_binary_owned_impl(*left, operator, exprs); split_binary_owned_impl(*right, operator, exprs) } 
Expr::Alias(Alias { expr, .. }) => split_binary_owned_impl(*expr, operator, exprs), other => { exprs.push(other); exprs } } } #[allow(dead_code)] /// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` /// /// See [`split_binary_owned`] for more details and an example. pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { split_binary_impl(expr, op, vec![]) } #[allow(dead_code)] fn split_binary_impl<'a>( expr: &'a Expr, operator: Operator, mut exprs: Vec<&'a Expr>, ) -> Vec<&'a Expr> { match expr { Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { let exprs = split_binary_impl(left, operator, exprs); split_binary_impl(right, operator, exprs) } Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs), other => { exprs.push(other); exprs } } } /// Combines an array of filter expressions into a single filter /// expression consisting of the input filter expressions joined with /// logical AND. /// /// Returns None if the filters array is empty. /// /// # Example /// ``` /// # use datafusion_python::datafusion_expr::{col, lit}; /// # use datafusion_python::datafusion_optimizer::utils::conjunction; /// // a=1 AND b=2 /// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); /// /// // [a=1, b=2] /// let split = vec![ /// col("a").eq(lit(1)), /// col("b").eq(lit(2)), /// ]; /// /// // use conjunction to join them together with `AND` /// assert_eq!(conjunction(split), Some(expr)); /// ``` pub fn conjunction(filters: impl IntoIterator) -> Option { filters.into_iter().reduce(|accum, expr| accum.and(expr)) } #[allow(dead_code)] /// Combines an array of filter expressions into a single filter /// expression consisting of the input filter expressions joined with /// logical OR. /// /// Returns None if the filters array is empty. 
pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
    filters.into_iter().reduce(|accum, expr| accum.or(expr))
}

/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with
/// its predicate be all `predicates` ANDed.
///
/// Returns a plan error when `predicates` is empty, since there is no predicate to
/// build the filter from.
#[allow(dead_code)]
pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
    // reduce filters to a single filter with an AND (left-associative, in slice order)
    let predicate = predicates
        .iter()
        .map(|predicate| (*predicate).to_owned())
        .reduce(and)
        .ok_or_else(|| {
            datafusion_python::datafusion_common::DataFusionError::Plan(
                "Cannot add a filter without any predicates".to_owned(),
            )
        })?;

    Ok(LogicalPlan::Filter(Filter::try_new(
        predicate,
        Arc::new(plan),
    )?))
}

/// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and
/// one not in the subquery (closed upon from outer scope)
///
/// # Arguments
///
/// * `exprs` - List of expressions that may or may not be joins
///
/// # Return value
///
/// Tuple of (expressions containing joins, remaining non-join expressions)
pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {
    let mut joins = vec![];
    let mut others = vec![];

    for filter in exprs {
        if !filter.contains_outer() {
            others.push(filter.clone());
            continue;
        }

        // The expression contains correlated predicates, so it is a join filter. An equality
        // whose two operands are the same expression is dropped entirely.
        let is_self_equality = matches!(
            filter,
            Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) if left.eq(right)
        );
        if !is_self_equality {
            joins.push(strip_outer_reference(filter.clone()));
        }
    }

    Ok((joins, others))
}

/// Returns the first (and only) element in a slice, or an error
///
/// # Arguments
///
/// * `slice` - The slice to extract from
///
/// # Return value
///
/// The first element, or an error
pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
    match slice {
        [it] => Ok(it),
        [] => Err(datafusion_python::datafusion_common::DataFusionError::Plan(
            "No items found!".to_owned(),
        )),
        _ => Err(datafusion_python::datafusion_common::DataFusionError::Plan(
            "More than one item found!".to_owned(),
        )),
    }
}

/// merge inputs schema into a single schema.
#[allow(dead_code)] pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { if inputs.len() == 1 { inputs[0].schema().clone().as_ref().clone() } else { inputs .iter() .map(|input| input.schema()) .fold(DFSchema::empty(), |mut lhs, rhs| { lhs.merge(rhs); lhs }) } } pub(crate) fn collect_subquery_cols( exprs: &[Expr], subquery_schema: DFSchemaRef, ) -> Result> { exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { let mut using_cols: Vec = vec![]; for col in expr.to_columns()?.into_iter() { if subquery_schema.has_column(&col) { using_cols.push(col); } } cols.extend(using_cols); Result::<_>::Ok(cols) }) } pub(crate) fn replace_qualified_name( expr: Expr, cols: &BTreeSet, subquery_alias: &str, ) -> Result { let alias_cols: Vec = cols .iter() .map(|col| Column::from_qualified_name(format!("{}.{}", subquery_alias, col.name))) .collect(); let replace_map: HashMap<&Column, &Column> = cols.iter().zip(alias_cols.iter()).collect(); replace_col(expr, &replace_map) } #[allow(dead_code)] /// Log the plan in debug/tracing mode after some part of the optimizer runs pub fn log_plan(description: &str, plan: &LogicalPlan) { debug!("{description}:\n{}\n", plan.display_indent()); trace!("{description}::\n{}\n", plan.display_indent_schema()); } #[cfg(test)] mod tests { use std::collections::HashSet; use datafusion_python::{ datafusion::arrow::datatypes::DataType, datafusion_common::Column, datafusion_expr::{col, expr::Cast, lit, utils::expr_to_columns}, }; use super::*; #[test] fn test_split_conjunction() { let expr = col("a"); let result = split_conjunction(&expr); assert_eq!(result, vec![&expr]); } #[test] fn test_split_conjunction_two() { let expr = col("a").eq(lit(5)).and(col("b")); let expr1 = col("a").eq(lit(5)); let expr2 = col("b"); let result = split_conjunction(&expr); assert_eq!(result, vec![&expr1, &expr2]); } #[test] fn test_split_conjunction_alias() { let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias")); let expr1 = col("a").eq(lit(5)); let expr2 = 
col("b"); // has no alias let result = split_conjunction(&expr); assert_eq!(result, vec![&expr1, &expr2]); } #[test] fn test_split_conjunction_or() { let expr = col("a").eq(lit(5)).or(col("b")); let result = split_conjunction(&expr); assert_eq!(result, vec![&expr]); } #[test] fn test_split_binary_owned() { let expr = col("a"); assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); } #[test] fn test_split_binary_owned_two() { assert_eq!( split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), vec![col("a").eq(lit(5)), col("b")] ); } #[test] fn test_split_binary_owned_different_op() { let expr = col("a").eq(lit(5)).or(col("b")); assert_eq!( // expr is connected by OR, but pass in AND split_binary_owned(expr.clone(), Operator::And), vec![expr] ); } #[test] fn test_split_conjunction_owned() { let expr = col("a"); assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); } #[test] fn test_split_conjunction_owned_two() { assert_eq!( split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), vec![col("a").eq(lit(5)), col("b")] ); } #[test] fn test_split_conjunction_owned_alias() { assert_eq!( split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), vec![ col("a").eq(lit(5)), // no alias on b col("b"), ] ); } #[test] fn test_conjunction_empty() { assert_eq!(conjunction(vec![]), None); } #[test] fn test_conjunction() { // `[A, B, C]` let expr = conjunction(vec![col("a"), col("b"), col("c")]); // --> `(A AND B) AND C` assert_eq!(expr, Some(col("a").and(col("b")).and(col("c")))); // which is different than `A AND (B AND C)` assert_ne!(expr, Some(col("a").and(col("b").and(col("c"))))); } #[test] fn test_disjunction_empty() { assert_eq!(disjunction(vec![]), None); } #[test] fn test_disjunction() { // `[A, B, C]` let expr = disjunction(vec![col("a"), col("b"), col("c")]); // --> `(A OR B) OR C` assert_eq!(expr, Some(col("a").or(col("b")).or(col("c")))); // which is different than `A OR (B OR C)` assert_ne!(expr, 
Some(col("a").or(col("b").or(col("c"))))); } #[test] fn test_split_conjunction_owned_or() { let expr = col("a").eq(lit(5)).or(col("b")); assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); } #[test] fn test_collect_expr() -> Result<()> { let mut accum: HashSet = HashSet::new(); expr_to_columns( &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), &mut accum, )?; expr_to_columns( &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), &mut accum, )?; assert_eq!(1, accum.len()); assert!(accum.contains(&Column::from_name("a"))); Ok(()) } } ================================================ FILE: src/sql/optimizer.rs ================================================ // Declare optimizer modules pub mod decorrelate_where_exists; pub mod decorrelate_where_in; pub mod dynamic_partition_pruning; pub mod join_reorder; pub mod utils; use std::sync::Arc; use datafusion_python::{ datafusion_common::DataFusionError, datafusion_expr::LogicalPlan, datafusion_optimizer::{ eliminate_cross_join::EliminateCrossJoin, eliminate_limit::EliminateLimit, eliminate_outer_join::EliminateOuterJoin, eliminate_project::EliminateProjection, filter_null_join_keys::FilterNullJoinKeys, optimizer::{Optimizer, OptimizerRule}, push_down_filter::PushDownFilter, push_down_limit::PushDownLimit, push_down_projection::PushDownProjection, rewrite_disjunctive_predicate::RewriteDisjunctivePredicate, scalar_subquery_to_join::ScalarSubqueryToJoin, simplify_expressions::SimplifyExpressions, unwrap_cast_in_comparison::UnwrapCastInComparison, OptimizerContext, }, }; use decorrelate_where_exists::DecorrelateWhereExists; use decorrelate_where_in::DecorrelateWhereIn; use dynamic_partition_pruning::DynamicPartitionPruning; use join_reorder::JoinReorder; use log::{debug, trace}; /// Houses the optimization logic for Dask-SQL. 
This optimization controls the optimizations /// and their ordering in regards to their impact on the underlying `LogicalPlan` instance pub struct DaskSqlOptimizer { optimizer: Optimizer, } impl DaskSqlOptimizer { /// Creates a new instance of the DaskSqlOptimizer with all the DataFusion desired /// optimizers as well as any custom `OptimizerRule` trait impls that might be desired. pub fn new( fact_dimension_ratio: Option, max_fact_tables: Option, preserve_user_order: Option, filter_selectivity: Option, ) -> Self { debug!("Creating new instance of DaskSqlOptimizer"); let rules: Vec> = vec![ Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), // Arc::new(ReplaceDistinctWithAggregate::new()), Arc::new(DecorrelateWhereExists::new()), Arc::new(DecorrelateWhereIn::new()), Arc::new(ScalarSubqueryToJoin::new()), //Arc::new(ExtractEquijoinPredicate::new()), // simplify expressions does not simplify expressions in subqueries, so we // run it again after running the optimizations that potentially converted // subqueries to joins Arc::new(SimplifyExpressions::new()), // Arc::new(MergeProjection::new()), Arc::new(RewriteDisjunctivePredicate::new()), // Arc::new(EliminateDuplicatedExpr::new()), // TODO: need to handle EmptyRelation for GPU cases // Arc::new(EliminateFilter::new()), Arc::new(EliminateCrossJoin::new()), // Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), // Arc::new(PropagateEmptyRelation::new()), Arc::new(FilterNullJoinKeys::default()), Arc::new(EliminateOuterJoin::new()), // Filters can't be pushed down past Limits, we should do PushDownFilter after PushDownLimit Arc::new(PushDownLimit::new()), Arc::new(PushDownFilter::new()), // Arc::new(SingleDistinctToGroupBy::new()), // Dask-SQL specific optimizations Arc::new(JoinReorder::new( fact_dimension_ratio, max_fact_tables, preserve_user_order, filter_selectivity, )), // The previous optimizations added expressions and projections, // that might benefit from the 
following rules Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), // Arc::new(CommonSubexprEliminate::new()), Arc::new(PushDownProjection::new()), Arc::new(EliminateProjection::new()), // PushDownProjection can pushdown Projections through Limits, do PushDownLimit again. Arc::new(PushDownLimit::new()), ]; Self { optimizer: Optimizer::with_rules(rules), } } // Create a separate instance of this optimization rule, since we want to ensure that it only // runs one time pub fn dynamic_partition_pruner(fact_dimension_ratio: Option) -> Self { let rule: Vec> = vec![Arc::new( DynamicPartitionPruning::new(fact_dimension_ratio.unwrap_or(0.3)), )]; Self { optimizer: Optimizer::with_rules(rule), } } /// Iterates through the configured `OptimizerRule`(s) to transform the input `LogicalPlan` /// to its final optimized form pub(crate) fn optimize(&self, plan: LogicalPlan) -> Result { let config = OptimizerContext::new(); self.optimizer.optimize(&plan, &config, Self::observe) } /// Iterates once through the configured `OptimizerRule`(s) to transform the input `LogicalPlan` /// to its final optimized form pub(crate) fn optimize_once(&self, plan: LogicalPlan) -> Result { let mut config = OptimizerContext::new(); config = OptimizerContext::with_max_passes(config, 1); self.optimizer.optimize(&plan, &config, Self::observe) } fn observe(optimized_plan: &LogicalPlan, optimization: &dyn OptimizerRule) { trace!( "== AFTER APPLYING RULE {} ==\n{}\n", optimization.name(), optimized_plan.display_indent() ); } } #[cfg(test)] mod tests { use std::{any::Any, collections::HashMap, sync::Arc}; use datafusion_python::{ datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}, datafusion_common::{config::ConfigOptions, DataFusionError, Result}, datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource}, datafusion_sql::{ planner::{ContextProvider, SqlToRel}, sqlparser::{ast::Statement, parser::Parser}, TableReference, }, }; use 
crate::{dialect::DaskDialect, sql::optimizer::DaskSqlOptimizer}; #[test] fn subquery_filter_with_cast() -> Result<()> { // regression test for https://github.com/apache/arrow-datafusion/issues/3760 let sql = "SELECT col_int32 FROM test \ WHERE col_int32 > (\ SELECT AVG(col_int32) FROM test \ WHERE col_utf8 BETWEEN '2002-05-08' \ AND (cast('2002-05-08' as date) + interval '5 days')\ )"; let plan = test_sql(sql)?; assert!(format!("{:?}", plan).contains(r#"<= Date32("11820")"#)); Ok(()) } fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = DaskDialect {}; let ast: Vec = Parser::parse_sql(&dialect, sql).unwrap(); let statement = &ast[0]; // create a logical query plan let schema_provider = MySchemaProvider::new(); let sql_to_rel = SqlToRel::new(&schema_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); // optimize the logical plan let optimizer = DaskSqlOptimizer::new(None, None, None, None); optimizer.optimize(plan) } struct MySchemaProvider { options: ConfigOptions, } impl MySchemaProvider { fn new() -> Self { Self { options: ConfigOptions::default(), } } } impl ContextProvider for MySchemaProvider { fn options(&self) -> &ConfigOptions { &self.options } fn get_table_provider( &self, name: TableReference, ) -> datafusion_python::datafusion_common::Result> { let table_name = name.table(); if table_name.starts_with("test") { let schema = Schema::new_with_metadata( vec![ Field::new("col_int32", DataType::Int32, true), Field::new("col_uint32", DataType::UInt32, true), Field::new("col_utf8", DataType::Utf8, true), Field::new("col_date32", DataType::Date32, true), Field::new("col_date64", DataType::Date64, true), ], HashMap::new(), ); Ok(Arc::new(MyTableSource { schema: Arc::new(schema), })) } else { Err(DataFusionError::Plan("table does not exist".to_string())) } } fn get_function_meta(&self, _name: &str) -> Option> { None } fn get_aggregate_meta(&self, _name: &str) -> Option> { None } fn get_variable_type(&self, 
_variable_names: &[String]) -> Option { None } fn get_window_meta( &self, _name: &str, ) -> Option> { None } } struct MyTableSource { schema: SchemaRef, } impl TableSource for MyTableSource { fn as_any(&self) -> &dyn Any { self } fn schema(&self) -> SchemaRef { self.schema.clone() } } } ================================================ FILE: src/sql/parser_utils.rs ================================================ use datafusion_python::datafusion_sql::sqlparser::{ast::ObjectName, parser::ParserError}; pub struct DaskParserUtils; impl DaskParserUtils { /// Retrieves the schema and object name from a `ObjectName` instance pub fn elements_from_object_name( obj_name: &ObjectName, ) -> Result<(Option, String), ParserError> { let identities: Vec = obj_name.0.iter().map(|f| f.value.clone()).collect(); match identities.len() { 1 => Ok((None, identities[0].clone())), 2 => Ok((Some(identities[0].clone()), identities[1].clone())), _ => Err(ParserError::ParserError( "TableFactor name only supports 1 or 2 elements".to_string(), )), } } } ================================================ FILE: src/sql/preoptimizer.rs ================================================ use std::collections::HashMap; use datafusion_python::{ datafusion::arrow::datatypes::{DataType, TimeUnit}, datafusion_common::{Column, DFField, ScalarValue}, datafusion_expr::{logical_plan::Filter, BinaryExpr, Expr, LogicalPlan, Operator}, }; // Sometimes, DataFusion's optimizer will raise an OptimizationException before we even get the // chance to correct it anywhere. In these cases, we can still modify the LogicalPlan as an // optimizer rule would, however we have to run it independently, separately of DataFusion's // optimization framework. Ideally, these "pre-optimization" rules aren't performing any complex // logic, but rather "pre-processing" the LogicalPlan for the optimizer. For example, the // datetime_coercion preoptimizer rule fixes a bug involving Timestamp-Int operations. 
// Helper function for datetime_coercion rule, which returns a vector of columns and literals // involved in a (possibly nested) BinaryExpr mathematical expression, and the mathematical // BinaryExpr itself fn extract_columns_and_literals(expr: &Expr) -> Vec<(Vec, Expr)> { let mut result = Vec::new(); if let Expr::BinaryExpr(b) = expr { let left = *b.left.clone(); let right = *b.right.clone(); if let Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Divide | Operator::Modulo = &b.op { let mut operands = Vec::new(); if let Expr::Column(_) | Expr::Literal(_) = left.clone() { operands.push(left); } else { let vector_of_vectors = extract_columns_and_literals(&left); let mut flattened = Vec::new(); for vector in vector_of_vectors { flattened.extend(vector.0); } operands.append(&mut flattened); } if let Expr::Column(_) | Expr::Literal(_) = right.clone() { operands.push(right); } else { let vector_of_vectors = extract_columns_and_literals(&right); let mut flattened = Vec::new(); for vector in vector_of_vectors { flattened.extend(vector.0); } operands.append(&mut flattened); } result.push((operands, expr.clone())); } else { if let Expr::BinaryExpr(_) = left { result.append(&mut extract_columns_and_literals(&left)); } if let Expr::BinaryExpr(_) = right { result.append(&mut extract_columns_and_literals(&right)); } } } result } // Helper function for datetime_coercion rule, which uses a LogicalPlan's schema to obtain the // datatype of a desired column fn find_data_type(column: Column, fields: Vec) -> Option { for field in fields { if let Some(qualifier) = field.qualifier() { if column.relation.is_some() && qualifier.table() == column.relation.clone().unwrap().table() && field.field().name() == &column.name { return Some(field.field().data_type().clone()); } } } None } // Helper function for datetime_coercion rule, which, given a BinaryExpr and a HashMap in which the // key represents a Literal and the value represents a Literal to replace the key with, 
returns the // modified BinaryExpr fn replace_literals(expr: Expr, replacements: HashMap) -> Expr { match expr { Expr::Literal(l) => { if let Some(new_literal) = replacements.get(&Expr::Literal(l.clone())) { new_literal.clone() } else { Expr::Literal(l) } } Expr::BinaryExpr(b) => { let left = replace_literals(*b.left, replacements.clone()); let right = replace_literals(*b.right, replacements); Expr::BinaryExpr(BinaryExpr { left: Box::new(left), op: b.op, right: Box::new(right), }) } _ => expr, } } // Helper function for datetime_coercion rule, which, given a BinaryExpr expr and a HashMap in // which the key represents a BinaryExpr and the value represents a BinaryExpr to replace the key // with, returns the modified expr fn replace_binary_exprs(expr: Expr, replacements: HashMap) -> Expr { match expr { Expr::BinaryExpr(b) => { if let Some(new_expr) = replacements.get(&Expr::BinaryExpr(b.clone())) { new_expr.clone() } else { let left = replace_binary_exprs(*b.left, replacements.clone()); let right = replace_binary_exprs(*b.right, replacements); Expr::BinaryExpr(BinaryExpr { left: Box::new(left), op: b.op, right: Box::new(right), }) } } _ => expr, } } // Preoptimization rule which detects when the user is trying to perform a binary operation on a // datetime and an integer, then converts the integer to a IntervalMonthDayNano. 
For example, if we // have a date_col + 5, we assume that we are adding 5 days to the date_col pub fn datetime_coercion(plan: &LogicalPlan) -> Option { match plan { LogicalPlan::Filter(f) => { let filter_expr = f.predicate.clone(); let columns_and_literals = extract_columns_and_literals(&filter_expr); let mut days_to_nanoseconds: Vec<(Expr, HashMap)> = Vec::new(); for vector in columns_and_literals.iter() { // Detect whether a timestamp is involved in the operation let mut is_timestamp_operation = false; for item in vector.0.iter() { if let Expr::Column(column) = item { if let Some(DataType::Timestamp(TimeUnit::Nanosecond, _)) = find_data_type(column.clone(), plan.schema().fields().clone()) { is_timestamp_operation = true; } } } // Convert an integer to an IntervalMonthDayNano if is_timestamp_operation { let mut find_replace = HashMap::new(); for item in vector.0.iter() { if let Expr::Literal(ScalarValue::Int64(i)) = item { let ns = i.unwrap() as i128 * 18446744073709552000; find_replace.insert( Expr::Literal(ScalarValue::Int64(*i)), Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(ns))), ); } } days_to_nanoseconds.push((vector.1.clone(), find_replace)); } } let mut binary_exprs = HashMap::new(); for replacements in days_to_nanoseconds.iter() { binary_exprs.insert( replacements.0.clone(), replace_literals(replacements.0.clone(), replacements.1.clone()), ); } let new_filter = replace_binary_exprs(filter_expr, binary_exprs); Some(LogicalPlan::Filter( Filter::try_new(new_filter, f.input.clone()).unwrap(), )) } _ => optimize_children(plan.clone()), } } // Function used to iterate through a LogicalPlan and update it accordingly fn optimize_children(existing_plan: LogicalPlan) -> Option { let plan = existing_plan.clone(); let new_exprs = plan.expressions(); let mut new_inputs = Vec::with_capacity(plan.inputs().len()); let mut plan_is_changed = false; for input in plan.inputs() { // Since datetime_coercion is the only preoptimizer rule that we have at the moment, we 
// hardcode it here. If additional preoptimizer rules are added in the future, this can be // modified let new_input = datetime_coercion(input); plan_is_changed = plan_is_changed || new_input.is_some(); new_inputs.push(new_input.unwrap_or_else(|| input.clone())) } if plan_is_changed { Some(plan.with_new_exprs(new_exprs, &new_inputs).ok()?) } else { Some(existing_plan) } } ================================================ FILE: src/sql/schema.rs ================================================ use std::collections::HashMap; use ::std::sync::{Arc, Mutex}; use pyo3::prelude::*; use super::types::PyDataType; use crate::sql::{function::DaskFunction, table}; #[pyclass(name = "DaskSchema", module = "dask_sql", subclass)] #[derive(Debug, Clone)] pub struct DaskSchema { #[pyo3(get, set)] pub(crate) name: String, pub(crate) tables: HashMap, pub(crate) functions: HashMap>>, } #[pymethods] impl DaskSchema { #[new] pub fn new(schema_name: &str) -> Self { Self { name: schema_name.to_owned(), tables: HashMap::new(), functions: HashMap::new(), } } pub fn add_table(&mut self, table: table::DaskTable) { self.tables.insert(table.table_name.clone(), table); } pub fn add_or_overload_function( &mut self, name: String, input_types: Vec, return_type: PyDataType, aggregation: bool, ) { self.functions .entry(name.clone()) .and_modify(|e| { (*e).lock() .unwrap() .add_type_mapping(input_types.clone(), return_type.clone()); }) .or_insert_with(|| { Arc::new(Mutex::new(DaskFunction::new( name, input_types, return_type, aggregation, ))) }); } } ================================================ FILE: src/sql/statement.rs ================================================ use pyo3::prelude::*; use crate::parser::DaskStatement; #[pyclass(name = "Statement", module = "dask_sql", subclass)] #[derive(Debug, Clone)] pub struct PyStatement { pub statement: DaskStatement, } impl From for DaskStatement { fn from(statement: PyStatement) -> DaskStatement { statement.statement } } impl From for PyStatement { fn 
from(statement: DaskStatement) -> PyStatement { PyStatement { statement } } } impl PyStatement { pub fn new(statement: DaskStatement) -> Self { Self { statement } } } ================================================ FILE: src/sql/table.rs ================================================ use std::{any::Any, sync::Arc}; use async_trait::async_trait; use datafusion_python::{ datafusion::arrow::datatypes::{DataType, Fields, SchemaRef}, datafusion_common::DFField, datafusion_expr::{Expr, LogicalPlan, TableProviderFilterPushDown, TableSource}, datafusion_optimizer::utils::split_conjunction, datafusion_sql::TableReference, }; use pyo3::prelude::*; use super::logical::{create_table::CreateTablePlanNode, predict_model::PredictModelPlanNode}; use crate::{ error::DaskPlannerError, sql::{ logical, types::{ rel_data_type::RelDataType, rel_data_type_field::RelDataTypeField, DaskTypeMap, SqlTypeName, }, }, }; /// DaskTable wrapper that is compatible with DataFusion logical query plans pub struct DaskTableSource { schema: SchemaRef, statistics: Option, filepath: Option, } impl DaskTableSource { /// Initialize a new `EmptyTable` from a schema pub fn new( schema: SchemaRef, statistics: Option, filepath: Option, ) -> Self { Self { schema, statistics, filepath, } } /// Access optional statistics associated with this table source pub fn statistics(&self) -> Option<&DaskStatistics> { self.statistics.as_ref() } /// Access optional filepath associated with this table source pub fn filepath(&self) -> Option<&String> { self.filepath.as_ref() } } /// Implement TableSource, used in the logical query plan and in logical query optimizations #[async_trait] impl TableSource for DaskTableSource { fn as_any(&self) -> &dyn Any { self } fn schema(&self) -> SchemaRef { self.schema.clone() } fn supports_filter_pushdown( &self, filter: &Expr, ) -> datafusion_python::datafusion_common::Result { let filters = split_conjunction(filter); if filters.iter().all(|f| is_supported_push_down_expr(f)) { // Push 
down filters to the tablescan operation if all are supported Ok(TableProviderFilterPushDown::Exact) } else if filters.iter().any(|f| is_supported_push_down_expr(f)) { // Partially apply the filter in the TableScan but retain // the Filter operator in the plan as well Ok(TableProviderFilterPushDown::Inexact) } else { Ok(TableProviderFilterPushDown::Unsupported) } } } fn is_supported_push_down_expr(_expr: &Expr) -> bool { // For now we support all kinds of expr's at this level true } #[pyclass(name = "DaskStatistics", module = "dask_sql", subclass)] #[derive(Debug, Clone)] pub struct DaskStatistics { row_count: f64, } #[pymethods] impl DaskStatistics { #[new] pub fn new(row_count: f64) -> Self { Self { row_count } } #[pyo3(name = "getRowCount")] pub fn get_row_count(&self) -> f64 { self.row_count } } #[pyclass(name = "DaskTable", module = "dask_sql", subclass)] #[derive(Debug, Clone)] pub struct DaskTable { pub(crate) schema_name: Option, pub(crate) table_name: String, pub(crate) statistics: DaskStatistics, pub(crate) columns: Vec<(String, DaskTypeMap)>, pub(crate) filepath: Option, } #[pymethods] impl DaskTable { #[new] pub fn new( schema_name: &str, table_name: &str, row_count: f64, columns: Option>, filepath: Option, ) -> Self { Self { schema_name: Some(schema_name.to_owned()), table_name: table_name.to_owned(), statistics: DaskStatistics::new(row_count), columns: columns.unwrap_or_default(), filepath, } } // TODO: Really wish we could accept a SqlTypeName instance here instead of a String for `column_type` .... 
#[pyo3(name = "add_column")] pub fn add_column(&mut self, column_name: &str, type_map: DaskTypeMap) { self.columns.push((column_name.to_owned(), type_map)); } #[pyo3(name = "getSchema")] pub fn get_schema(&self) -> PyResult> { Ok(self.schema_name.clone()) } #[pyo3(name = "getTableName")] pub fn get_table_name(&self) -> PyResult { Ok(self.table_name.clone()) } #[pyo3(name = "getQualifiedName")] pub fn qualified_name(&self, plan: logical::PyLogicalPlan) -> Vec { let mut qualified_name = match &self.schema_name { Some(schema_name) => vec![schema_name.clone()], None => vec![], }; match plan.original_plan { LogicalPlan::TableScan(table_scan) => { qualified_name.push(table_scan.table_name.to_string()); } _ => { qualified_name.push(self.table_name.clone()); } } qualified_name } #[pyo3(name = "getRowType")] pub fn row_type(&self) -> RelDataType { let mut fields: Vec = Vec::new(); for (name, data_type) in &self.columns { fields.push(RelDataTypeField::new(name.as_str(), data_type.clone(), 255)); } RelDataType::new(false, fields) } } /// Traverses the logical plan to locate the Table associated with the query pub(crate) fn table_from_logical_plan( plan: &LogicalPlan, ) -> Result, DaskPlannerError> { match plan { LogicalPlan::Projection(projection) => table_from_logical_plan(&projection.input), LogicalPlan::Filter(filter) => table_from_logical_plan(&filter.input), LogicalPlan::TableScan(table_scan) => { // Get the TableProvider for this Table instance let tbl_provider: Arc = table_scan.source.clone(); let tbl_schema: SchemaRef = tbl_provider.schema(); let fields: &Fields = tbl_schema.fields(); let mut cols: Vec<(String, DaskTypeMap)> = Vec::new(); for field in fields { let data_type: &DataType = field.data_type(); cols.push(( String::from(field.name()), DaskTypeMap::from( SqlTypeName::from_arrow(data_type)?, data_type.clone().into(), ), )); } let table_ref: TableReference = table_scan.table_name.clone(); let (schema, tbl) = match table_ref { TableReference::Bare { table } => 
("".to_string(), table), TableReference::Partial { schema, table } => (schema.to_string(), table), TableReference::Full { catalog: _, schema, table, } => (schema.to_string(), table), }; Ok(Some(DaskTable { schema_name: Some(schema), table_name: String::from(tbl), statistics: DaskStatistics { row_count: 0.0 }, columns: cols, filepath: None, })) } LogicalPlan::Join(join) => { // TODO: Don't always hardcode the left table_from_logical_plan(&join.left) } LogicalPlan::Aggregate(agg) => table_from_logical_plan(&agg.input), LogicalPlan::SubqueryAlias(alias) => table_from_logical_plan(&alias.input), LogicalPlan::EmptyRelation(empty_relation) => { let fields: &Vec = empty_relation.schema.fields(); let mut cols: Vec<(String, DaskTypeMap)> = Vec::new(); for field in fields { let data_type: &DataType = field.data_type(); cols.push(( String::from(field.name()), DaskTypeMap::from( SqlTypeName::from_arrow(data_type)?, data_type.clone().into(), ), )); } Ok(Some(DaskTable { schema_name: Some(String::from("EmptySchema")), table_name: String::from("EmptyRelation"), statistics: DaskStatistics { row_count: 0.0 }, columns: cols, filepath: None, })) } LogicalPlan::Extension(ex) => { let node = ex.node.as_any(); if let Some(e) = node.downcast_ref::() { Ok(Some(DaskTable { schema_name: e.schema_name.clone(), table_name: e.table_name.clone(), statistics: DaskStatistics { row_count: 0.0 }, columns: vec![], filepath: None, })) } else if let Some(e) = node.downcast_ref::() { Ok(Some(DaskTable { schema_name: e.schema_name.clone(), table_name: e.model_name.clone(), statistics: DaskStatistics { row_count: 0.0 }, columns: vec![], filepath: None, })) } else { Err(DaskPlannerError::Internal(format!( "table_from_logical_plan: unimplemented LogicalPlan type {plan:?} encountered" ))) } } _ => Err(DaskPlannerError::Internal(format!( "table_from_logical_plan: unimplemented LogicalPlan type {plan:?} encountered" ))), } } ================================================ FILE: src/sql/types/rel_data_type.rs 
================================================ use std::collections::HashMap; use pyo3::prelude::*; use crate::sql::{exceptions::py_runtime_err, types::rel_data_type_field::RelDataTypeField}; const PRECISION_NOT_SPECIFIED: i32 = i32::MIN; const SCALE_NOT_SPECIFIED: i32 = -1; /// RelDataType represents the type of a scalar expression or entire row returned from a relational expression. #[pyclass(name = "RelDataType", module = "dask_sql", subclass)] #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct RelDataType { nullable: bool, field_list: Vec, } /// RelDataType represents the type of a scalar expression or entire row returned from a relational expression. #[pymethods] impl RelDataType { #[new] pub fn new(nullable: bool, fields: Vec) -> Self { Self { nullable, field_list: fields, } } /// Looks up a field by name. /// /// # Arguments /// /// * `field_name` - A String containing the name of the field to find /// * `case_sensitive` - True if column name matching should be case sensitive and false otherwise #[pyo3(name = "getField")] pub fn field(&self, field_name: &str, case_sensitive: bool) -> PyResult { let field_map: HashMap = self.field_map(); if case_sensitive && !field_map.is_empty() { Ok(field_map.get(field_name).unwrap().clone()) } else { for field in &self.field_list { if (case_sensitive && field.name().eq(field_name)) || (!case_sensitive && field.name().eq_ignore_ascii_case(field_name)) { return Ok(field.clone()); } } // TODO: Throw a proper error here Err(py_runtime_err(format!( "Unable to find RelDataTypeField with name {field_name:?} in the RelDataType field_list" ))) } } /// Returns a map from field names to fields. /// /// # Notes /// /// * If several fields have the same name, the map contains the first. 
#[pyo3(name = "getFieldMap")] pub fn field_map(&self) -> HashMap { let mut fields: HashMap = HashMap::new(); for field in &self.field_list { fields.insert(String::from(field.name()), field.clone()); } fields } /// Gets the fields in a struct type. The field count is equal to the size of the returned list. #[pyo3(name = "getFieldList")] pub fn field_list(&self) -> Vec { self.field_list.clone() } /// Returns the names of all of the columns in a given DaskTable #[pyo3(name = "getFieldNames")] pub fn field_names(&self) -> Vec { let mut field_names: Vec = Vec::new(); for field in &self.field_list { field_names.push(field.qualified_name()); } field_names } /// Returns the number of fields in a struct type. #[pyo3(name = "getFieldCount")] pub fn field_count(&self) -> usize { self.field_list.len() } #[pyo3(name = "isStruct")] pub fn is_struct(&self) -> bool { !self.field_list.is_empty() } /// Queries whether this type allows null values. #[pyo3(name = "isNullable")] pub fn is_nullable(&self) -> bool { self.nullable } #[pyo3(name = "getPrecision")] pub fn precision(&self) -> i32 { PRECISION_NOT_SPECIFIED } #[pyo3(name = "getScale")] pub fn scale(&self) -> i32 { SCALE_NOT_SPECIFIED } } ================================================ FILE: src/sql/types/rel_data_type_field.rs ================================================ use std::fmt; use datafusion_python::{ datafusion_common::{DFField, DFSchema}, datafusion_sql::TableReference, }; use pyo3::prelude::*; use crate::{ error::Result, sql::types::{DaskTypeMap, SqlTypeName}, }; /// RelDataTypeField represents the definition of a field in a structured RelDataType. 
#[pyclass(name = "RelDataTypeField", module = "dask_sql", subclass)] #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct RelDataTypeField { qualifier: Option, name: String, data_type: DaskTypeMap, index: usize, } // Functions that should not be presented to Python are placed here impl RelDataTypeField { pub fn from(field: &DFField, schema: &DFSchema) -> Result { let qualifier: Option<&TableReference> = field.qualifier(); Ok(RelDataTypeField { qualifier: qualifier.map(|qualifier| qualifier.to_string()), name: field.name().clone(), data_type: DaskTypeMap { sql_type: SqlTypeName::from_arrow(field.data_type())?, data_type: field.data_type().clone().into(), }, index: schema .index_of_column_by_name(qualifier, field.name())? .unwrap(), }) } } #[pymethods] impl RelDataTypeField { #[new] pub fn new(name: &str, type_map: DaskTypeMap, index: usize) -> Self { Self { qualifier: None, name: name.to_owned(), data_type: type_map, index, } } #[pyo3(name = "getQualifier")] pub fn qualifier(&self) -> Option { self.qualifier.clone() } #[pyo3(name = "getName")] pub fn name(&self) -> &str { &self.name } #[pyo3(name = "getQualifiedName")] pub fn qualified_name(&self) -> String { match &self.qualifier() { Some(qualifier) => format!("{}.{}", &qualifier, self.name()), None => self.name().to_string(), } } #[pyo3(name = "getIndex")] pub fn index(&self) -> usize { self.index } #[pyo3(name = "getType")] pub fn data_type(&self) -> DaskTypeMap { self.data_type.clone() } /// Since this logic is being ported from Java getKey is synonymous with getName. /// Alas it is used in certain places so it is implemented here to allow other /// places in the code base to not have to change. #[pyo3(name = "getKey")] pub fn get_key(&self) -> &str { self.name() } /// Since this logic is being ported from Java getValue is synonymous with getType. /// Alas it is used in certain places so it is implemented here to allow other /// places in the code base to not have to change. 
#[pyo3(name = "getValue")] pub fn get_value(&self) -> DaskTypeMap { self.data_type() } #[pyo3(name = "setValue")] pub fn set_value(&mut self, data_type: DaskTypeMap) { self.data_type = data_type } // TODO: Uncomment after implementing in RelDataType // #[pyo3(name = "isDynamicStar")] // pub fn is_dynamic_star(&self) -> bool { // self.data_type.getSqlTypeName() == SqlTypeName.DYNAMIC_STAR // } } impl fmt::Display for RelDataTypeField { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fmt.write_str("Field: ")?; fmt.write_str(&self.name)?; fmt.write_str(" - Index: ")?; fmt.write_str(&self.index.to_string())?; // TODO: Uncomment this after implementing the Display trait in RelDataType // fmt.write_str(" - DataType: ")?; // fmt.write_str(self.data_type.to_string())?; Ok(()) } } ================================================ FILE: src/sql/types.rs ================================================ pub mod rel_data_type; pub mod rel_data_type_field; use std::sync::Arc; use datafusion_python::{ datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit}, datafusion_sql::sqlparser::{ast::DataType as SQLType, parser::Parser, tokenizer::Tokenizer}, }; use pyo3::{prelude::*, types::PyDict}; use crate::{dialect::DaskDialect, error::DaskPlannerError, sql::exceptions::py_type_err}; #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[pyclass(name = "RexType", module = "dask_sql")] pub enum RexType { Alias, Literal, Call, Reference, ScalarSubquery, Other, } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[pyclass(name = "DaskTypeMap", module = "dask_sql", subclass)] /// Represents a Python Data Type. This is needed instead of simple /// Enum instances because PyO3 can only support unit variants as /// of version 0.16 which means Enums like `DataType::TIMESTAMP_WITH_LOCAL_TIME_ZONE` /// which generally hold `unit` and `tz` information are unable to /// do that so data is lost. 
This struct aims to solve that issue /// by taking the type Enum from Python and some optional extra /// parameters that can be used to properly create those DataType /// instances in Rust. pub struct DaskTypeMap { sql_type: SqlTypeName, data_type: PyDataType, } /// Functions not exposed to Python impl DaskTypeMap { pub fn from(sql_type: SqlTypeName, data_type: PyDataType) -> Self { DaskTypeMap { sql_type, data_type, } } } #[pymethods] impl DaskTypeMap { #[new] #[pyo3(signature = (sql_type, **py_kwargs))] fn new(sql_type: SqlTypeName, py_kwargs: Option<&PyDict>) -> PyResult { let d_type: DataType = match sql_type { SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE => { let (unit, tz) = match py_kwargs { Some(dict) => { let tz: Option> = match dict.get_item("tz") { Some(e) => { let res: PyResult = e.extract(); Some(Arc::from(>::as_ref( &res.unwrap(), ))) } None => None, }; let unit: TimeUnit = match dict.get_item("unit") { Some(e) => { let res: PyResult<&str> = e.extract(); match res.unwrap() { "Second" => TimeUnit::Second, "Millisecond" => TimeUnit::Millisecond, "Microsecond" => TimeUnit::Microsecond, "Nanosecond" => TimeUnit::Nanosecond, _ => TimeUnit::Nanosecond, } } // Default to Nanosecond which is common if not present None => TimeUnit::Nanosecond, }; (unit, tz) } // Default to Nanosecond and None for tz which is common if not present None => (TimeUnit::Nanosecond, None), }; DataType::Timestamp(unit, tz) } SqlTypeName::TIMESTAMP => { let (unit, tz) = match py_kwargs { Some(dict) => { let tz: Option> = match dict.get_item("tz") { Some(e) => { let res: PyResult = e.extract(); Some(Arc::from(>::as_ref( &res.unwrap(), ))) } None => None, }; let unit: TimeUnit = match dict.get_item("unit") { Some(e) => { let res: PyResult<&str> = e.extract(); match res.unwrap() { "Second" => TimeUnit::Second, "Millisecond" => TimeUnit::Millisecond, "Microsecond" => TimeUnit::Microsecond, "Nanosecond" => TimeUnit::Nanosecond, _ => TimeUnit::Nanosecond, } } // Default to Nanosecond which 
is common if not present None => TimeUnit::Nanosecond, }; (unit, tz) } // Default to Nanosecond and None for tz which is common if not present None => (TimeUnit::Nanosecond, None), }; DataType::Timestamp(unit, tz) } SqlTypeName::DECIMAL => { let (precision, scale) = match py_kwargs { Some(dict) => { let precision: u8 = match dict.get_item("precision") { Some(e) => { let res: PyResult = e.extract(); res.unwrap() } None => 38, }; let scale: i8 = match dict.get_item("scale") { Some(e) => { let res: PyResult = e.extract(); res.unwrap() } None => 0, }; (precision, scale) } None => (38, 10), }; DataType::Decimal128(precision, scale) } _ => sql_type.to_arrow()?, }; Ok(DaskTypeMap { sql_type, data_type: d_type.into(), }) } fn __str__(&self) -> String { format!("{:?}", self.sql_type) } #[pyo3(name = "getSqlType")] pub fn sql_type(&self) -> SqlTypeName { self.sql_type.clone() } #[pyo3(name = "getDataType")] pub fn data_type(&self) -> PyDataType { self.data_type.clone() } } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[pyclass(name = "PyDataType", module = "dask_sql", subclass)] pub struct PyDataType { data_type: DataType, } #[pymethods] impl PyDataType { /// Gets the precision/scale represented by the PyDataType's decimal datatype #[pyo3(name = "getPrecisionScale")] pub fn get_precision_scale(&self) -> PyResult<(u8, i8)> { Ok(match &self.data_type { DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => { (*precision, *scale) } _ => { return Err(py_type_err(format!( "Catch all triggered in get_precision_scale, {:?}", &self.data_type ))) } }) } } impl From for DataType { fn from(data_type: PyDataType) -> DataType { data_type.data_type } } impl From for PyDataType { fn from(data_type: DataType) -> PyDataType { PyDataType { data_type } } } /// Enumeration of the type names which can be used to construct a SQL type. 
Since /// several SQL types do not exist as Rust types and also because the Enum /// `SqlTypeName` is already used in the Python Dask-SQL code base this enum is used /// in place of just using the built-in Rust types. #[allow(non_camel_case_types)] #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[pyclass(name = "SqlTypeName", module = "dask_sql")] pub enum SqlTypeName { ANY, ARRAY, BIGINT, BINARY, BOOLEAN, CHAR, COLUMN_LIST, CURSOR, DATE, DECIMAL, DISTINCT, DOUBLE, DYNAMIC_STAR, FLOAT, GEOMETRY, INTEGER, INTERVAL, INTERVAL_DAY, INTERVAL_DAY_HOUR, INTERVAL_DAY_MINUTE, INTERVAL_DAY_SECOND, INTERVAL_HOUR, INTERVAL_HOUR_MINUTE, INTERVAL_HOUR_SECOND, INTERVAL_MINUTE, INTERVAL_MINUTE_SECOND, INTERVAL_MONTH, INTERVAL_MONTH_DAY_NANOSECOND, INTERVAL_SECOND, INTERVAL_YEAR, INTERVAL_YEAR_MONTH, MAP, MULTISET, NULL, OTHER, REAL, ROW, SARG, SMALLINT, STRUCTURED, SYMBOL, TIME, TIME_WITH_LOCAL_TIME_ZONE, TIMESTAMP, TIMESTAMP_WITH_LOCAL_TIME_ZONE, TINYINT, UNKNOWN, VARBINARY, VARCHAR, } impl SqlTypeName { pub fn to_arrow(&self) -> Result { match self { SqlTypeName::NULL => Ok(DataType::Null), SqlTypeName::BOOLEAN => Ok(DataType::Boolean), SqlTypeName::TINYINT => Ok(DataType::Int8), SqlTypeName::SMALLINT => Ok(DataType::Int16), SqlTypeName::INTEGER => Ok(DataType::Int32), SqlTypeName::BIGINT => Ok(DataType::Int64), SqlTypeName::REAL => Ok(DataType::Float16), SqlTypeName::FLOAT => Ok(DataType::Float32), SqlTypeName::DOUBLE => Ok(DataType::Float64), SqlTypeName::DATE => Ok(DataType::Date64), SqlTypeName::VARCHAR => Ok(DataType::Utf8), _ => Err(DaskPlannerError::Internal(format!( "Cannot determine Arrow type for Dask SQL type '{self:?}'" ))), } } pub fn from_arrow(arrow_type: &DataType) -> Result { match arrow_type { DataType::Null => Ok(SqlTypeName::NULL), DataType::Boolean => Ok(SqlTypeName::BOOLEAN), DataType::Int8 => Ok(SqlTypeName::TINYINT), DataType::Int16 => Ok(SqlTypeName::SMALLINT), DataType::Int32 => 
Ok(SqlTypeName::INTEGER), DataType::Int64 => Ok(SqlTypeName::BIGINT), DataType::UInt8 => Ok(SqlTypeName::TINYINT), DataType::UInt16 => Ok(SqlTypeName::SMALLINT), DataType::UInt32 => Ok(SqlTypeName::INTEGER), DataType::UInt64 => Ok(SqlTypeName::BIGINT), DataType::Float16 => Ok(SqlTypeName::REAL), DataType::Float32 => Ok(SqlTypeName::FLOAT), DataType::Float64 => Ok(SqlTypeName::DOUBLE), DataType::Time32(_) | DataType::Time64(_) => Ok(SqlTypeName::TIME), DataType::Timestamp(_unit, tz) => match tz { Some(_) => Ok(SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE), None => Ok(SqlTypeName::TIMESTAMP), }, DataType::Date32 => Ok(SqlTypeName::DATE), DataType::Date64 => Ok(SqlTypeName::DATE), DataType::Interval(unit) => match unit { IntervalUnit::DayTime => Ok(SqlTypeName::INTERVAL_DAY), IntervalUnit::YearMonth => Ok(SqlTypeName::INTERVAL_YEAR_MONTH), IntervalUnit::MonthDayNano => Ok(SqlTypeName::INTERVAL_MONTH_DAY_NANOSECOND), }, DataType::Binary => Ok(SqlTypeName::BINARY), DataType::FixedSizeBinary(_size) => Ok(SqlTypeName::VARBINARY), DataType::Utf8 => Ok(SqlTypeName::CHAR), DataType::LargeUtf8 => Ok(SqlTypeName::VARCHAR), DataType::Struct(_fields) => Ok(SqlTypeName::STRUCTURED), DataType::Decimal128(_precision, _scale) => Ok(SqlTypeName::DECIMAL), DataType::Decimal256(_precision, _scale) => Ok(SqlTypeName::DECIMAL), DataType::Map(_field, _bool) => Ok(SqlTypeName::MAP), _ => Err(DaskPlannerError::Internal(format!( "Cannot determine Dask SQL type for Arrow type '{arrow_type:?}'" ))), } } } #[pymethods] impl SqlTypeName { #[pyo3(name = "fromString")] #[staticmethod] pub fn py_from_string(input_type: &str) -> PyResult { SqlTypeName::from_string(input_type).map_err(|e| e.into()) } } impl SqlTypeName { pub fn from_string(input_type: &str) -> Result { match input_type.to_uppercase().as_ref() { "ANY" => Ok(SqlTypeName::ANY), "ARRAY" => Ok(SqlTypeName::ARRAY), "NULL" => Ok(SqlTypeName::NULL), "BOOLEAN" => Ok(SqlTypeName::BOOLEAN), "COLUMN_LIST" => Ok(SqlTypeName::COLUMN_LIST), 
"DISTINCT" => Ok(SqlTypeName::DISTINCT), "CURSOR" => Ok(SqlTypeName::CURSOR), "TINYINT" => Ok(SqlTypeName::TINYINT), "SMALLINT" => Ok(SqlTypeName::SMALLINT), "INT" => Ok(SqlTypeName::INTEGER), "INTEGER" => Ok(SqlTypeName::INTEGER), "BIGINT" => Ok(SqlTypeName::BIGINT), "REAL" => Ok(SqlTypeName::REAL), "FLOAT" => Ok(SqlTypeName::FLOAT), "GEOMETRY" => Ok(SqlTypeName::GEOMETRY), "DOUBLE" => Ok(SqlTypeName::DOUBLE), "TIME" => Ok(SqlTypeName::TIME), "TIME_WITH_LOCAL_TIME_ZONE" => Ok(SqlTypeName::TIME_WITH_LOCAL_TIME_ZONE), "TIMESTAMP" => Ok(SqlTypeName::TIMESTAMP), "TIMESTAMP_WITH_LOCAL_TIME_ZONE" => Ok(SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE), "DATE" => Ok(SqlTypeName::DATE), "INTERVAL" => Ok(SqlTypeName::INTERVAL), "INTERVAL_DAY" => Ok(SqlTypeName::INTERVAL_DAY), "INTERVAL_DAY_HOUR" => Ok(SqlTypeName::INTERVAL_DAY_HOUR), "INTERVAL_DAY_MINUTE" => Ok(SqlTypeName::INTERVAL_DAY_MINUTE), "INTERVAL_DAY_SECOND" => Ok(SqlTypeName::INTERVAL_DAY_SECOND), "INTERVAL_HOUR" => Ok(SqlTypeName::INTERVAL_HOUR), "INTERVAL_HOUR_MINUTE" => Ok(SqlTypeName::INTERVAL_HOUR_MINUTE), "INTERVAL_HOUR_SECOND" => Ok(SqlTypeName::INTERVAL_HOUR_SECOND), "INTERVAL_MINUTE" => Ok(SqlTypeName::INTERVAL_MINUTE), "INTERVAL_MINUTE_SECOND" => Ok(SqlTypeName::INTERVAL_MINUTE_SECOND), "INTERVAL_MONTH" => Ok(SqlTypeName::INTERVAL_MONTH), "INTERVAL_SECOND" => Ok(SqlTypeName::INTERVAL_SECOND), "INTERVAL_YEAR" => Ok(SqlTypeName::INTERVAL_YEAR), "INTERVAL_YEAR_MONTH" => Ok(SqlTypeName::INTERVAL_YEAR_MONTH), "MAP" => Ok(SqlTypeName::MAP), "MULTISET" => Ok(SqlTypeName::MULTISET), "OTHER" => Ok(SqlTypeName::OTHER), "ROW" => Ok(SqlTypeName::ROW), "SARG" => Ok(SqlTypeName::SARG), "BINARY" => Ok(SqlTypeName::BINARY), "VARBINARY" => Ok(SqlTypeName::VARBINARY), "CHAR" => Ok(SqlTypeName::CHAR), "VARCHAR" | "STRING" => Ok(SqlTypeName::VARCHAR), "STRUCTURED" => Ok(SqlTypeName::STRUCTURED), "SYMBOL" => Ok(SqlTypeName::SYMBOL), "DECIMAL" => Ok(SqlTypeName::DECIMAL), "DYNAMIC_STAT" => Ok(SqlTypeName::DYNAMIC_STAR), "UNKNOWN" 
=> Ok(SqlTypeName::UNKNOWN), _ => { // complex data type name so use the sqlparser let dialect = DaskDialect {}; let mut tokenizer = Tokenizer::new(&dialect, input_type); let tokens = tokenizer.tokenize().map_err(DaskPlannerError::from)?; let mut parser = Parser::new(&dialect).with_tokens(tokens); match parser.parse_data_type().map_err(DaskPlannerError::from)? { SQLType::Decimal(_) => Ok(SqlTypeName::DECIMAL), SQLType::Binary(_) => Ok(SqlTypeName::BINARY), SQLType::Varbinary(_) => Ok(SqlTypeName::VARBINARY), SQLType::Varchar(_) | SQLType::Nvarchar(_) => Ok(SqlTypeName::VARCHAR), SQLType::Char(_) => Ok(SqlTypeName::CHAR), _ => Err(DaskPlannerError::Internal(format!( "Cannot determine Dask SQL type for '{input_type}'" ))), } } } } } #[cfg(test)] mod test { use crate::sql::types::SqlTypeName; #[test] fn invalid_type_name() { assert_eq!( "Internal Error: Cannot determine Dask SQL type for 'bob'", SqlTypeName::from_string("bob") .expect_err("invalid type name") .to_string() ); } #[test] fn string() { assert_expected("VARCHAR", "string"); } #[test] fn varchar_n() { assert_expected("VARCHAR", "VARCHAR(10)"); } #[test] fn decimal_p_s() { assert_expected("DECIMAL", "DECIMAL(10, 2)"); } fn assert_expected(expected: &str, input: &str) { assert_eq!( expected, &format!("{:?}", SqlTypeName::from_string(input).unwrap()) ); } } ================================================ FILE: src/sql.rs ================================================ pub mod column; pub mod exceptions; pub mod function; pub mod logical; pub mod optimizer; pub mod parser_utils; pub mod preoptimizer; pub mod schema; pub mod statement; pub mod table; pub mod types; use std::{collections::HashMap, sync::Arc}; use datafusion_python::{ datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}, datafusion_common::{ config::ConfigOptions, tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}, DFSchema, DataFusionError, }, datafusion_expr::{ logical_plan::Extension, AccumulatorFactoryFunction, AggregateUDF, 
LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, TableSource, TypeSignature, Volatility, }, datafusion_sql::{ parser::Statement as DFStatement, planner::{ContextProvider, SqlToRel}, ResolvedTableReference, TableReference, }, }; use log::{debug, warn}; use pyo3::prelude::*; use self::logical::{ create_catalog_schema::CreateCatalogSchemaPlanNode, drop_schema::DropSchemaPlanNode, use_schema::UseSchemaPlanNode, }; use crate::{ dialect::DaskDialect, parser::{DaskParser, DaskStatement}, sql::{ exceptions::{py_optimization_exp, py_parsing_exp, py_runtime_err}, logical::{ alter_schema::AlterSchemaPlanNode, alter_table::AlterTablePlanNode, analyze_table::AnalyzeTablePlanNode, create_experiment::CreateExperimentPlanNode, create_model::CreateModelPlanNode, create_table::CreateTablePlanNode, describe_model::DescribeModelPlanNode, drop_model::DropModelPlanNode, export_model::ExportModelPlanNode, predict_model::PredictModelPlanNode, show_columns::ShowColumnsPlanNode, show_models::ShowModelsPlanNode, show_schemas::ShowSchemasPlanNode, show_tables::ShowTablesPlanNode, PyLogicalPlan, }, preoptimizer::datetime_coercion, }, }; /// DaskSQLContext is main interface used for interacting with DataFusion to /// parse SQL queries, build logical plans, and optimize logical plans. /// /// The following example demonstrates how to generate an optimized LogicalPlan /// from SQL using DaskSQLContext. 
#[pyclass(name = "DaskSQLContext", module = "dask_sql", subclass)] #[derive(Debug, Clone)] pub struct DaskSQLContext { current_catalog: String, current_schema: String, schemas: HashMap, options: ConfigOptions, optimizer_config: DaskSQLOptimizerConfig, } #[pyclass(name = "DaskSQLOptimizerConfig", module = "dask_sql", subclass)] #[derive(Debug, Clone)] pub struct DaskSQLOptimizerConfig { dynamic_partition_pruning: bool, fact_dimension_ratio: Option, max_fact_tables: Option, preserve_user_order: Option, filter_selectivity: Option, } #[pymethods] impl DaskSQLOptimizerConfig { #[new] pub fn new( dynamic_partition_pruning: bool, fact_dimension_ratio: Option, max_fact_tables: Option, preserve_user_order: Option, filter_selectivity: Option, ) -> Self { Self { dynamic_partition_pruning, fact_dimension_ratio, max_fact_tables, preserve_user_order, filter_selectivity, } } } impl ContextProvider for DaskSQLContext { fn get_table_provider( &self, name: TableReference, ) -> Result, DataFusionError> { let reference: ResolvedTableReference = name .clone() .resolve(&self.current_catalog, &self.current_schema); if reference.catalog != self.current_catalog { // there is a single catalog in Dask SQL return Err(DataFusionError::Plan(format!( "Cannot resolve catalog '{}'", reference.catalog ))); } let schema_name = reference.clone().schema.into_owned(); match self.schemas.get(&schema_name) { Some(schema) => { let mut resp = None; for table in schema.tables.values() { if table.table_name.eq(&name.table()) { // Build the Schema here let mut fields: Vec = Vec::new(); // Iterate through the DaskTable instance and create a Schema instance for (column_name, column_type) in &table.columns { fields.push(Field::new( column_name, DataType::from(column_type.data_type()), true, )); } resp = Some(Schema::new(fields)); } } // If the Table is not found return None. 
DataFusion will handle the error propagation match resp { Some(e) => { let table_ref = &self .schemas .get(reference.schema.as_ref()) .unwrap() .tables .get(reference.table.as_ref()) .unwrap(); let statistics = &table_ref.statistics; let filepath = &table_ref.filepath; if statistics.get_row_count() == 0.0 { Ok(Arc::new(table::DaskTableSource::new( Arc::new(e), None, filepath.clone(), ))) } else { Ok(Arc::new(table::DaskTableSource::new( Arc::new(e), Some(statistics.clone()), filepath.clone(), ))) } } None => Err(DataFusionError::Plan(format!( "Table '{}.{}.{}' not found", reference.catalog, reference.schema, reference.table ))), } } None => Err(DataFusionError::Plan(format!( "Unable to locate Schema: '{}.{}'", reference.catalog, reference.schema ))), } } fn get_function_meta(&self, name: &str) -> Option> { let fun: ScalarFunctionImplementation = Arc::new(|_| Err(DataFusionError::NotImplemented("".to_string()))); let numeric_datatypes = vec![ DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64, DataType::UInt8, DataType::UInt16, DataType::UInt32, DataType::UInt64, DataType::Float16, DataType::Float32, DataType::Float64, ]; match name { "year" => { let sig = Signature::exact( vec![DataType::Timestamp(TimeUnit::Nanosecond, None)], Volatility::Immutable, ); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } "last_day" => { let sig = Signature::exact( vec![DataType::Timestamp(TimeUnit::Nanosecond, None)], Volatility::Immutable, ); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Timestamp(TimeUnit::Nanosecond, None)))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } "timestampceil" | "timestampfloor" => { // let sig = Signature::exact( // vec![DataType::Timestamp(TimeUnit::Nanosecond, None), DataType::Date64, DataType::Utf8], // Volatility::Immutable, // ); let sig = Signature::one_of( vec![ TypeSignature::Exact(vec![DataType::Date64, 
DataType::Utf8]), TypeSignature::Exact(vec![ DataType::Timestamp(TimeUnit::Nanosecond, None), DataType::Utf8, ]), ], Volatility::Immutable, ); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Date64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } "timestampadd" => { let sig = Signature::one_of( vec![ TypeSignature::Exact(vec![ DataType::Utf8, DataType::Int64, DataType::Date64, ]), TypeSignature::Exact(vec![ DataType::Utf8, DataType::Int64, DataType::Timestamp(TimeUnit::Nanosecond, None), ]), TypeSignature::Exact(vec![ DataType::Utf8, DataType::Int64, DataType::Int64, ]), ], Volatility::Immutable, ); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Date64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } "timestampdiff" => { let sig = Signature::one_of( vec![ TypeSignature::Exact(vec![ DataType::Utf8, DataType::Timestamp(TimeUnit::Nanosecond, None), DataType::Timestamp(TimeUnit::Nanosecond, None), ]), TypeSignature::Exact(vec![ DataType::Utf8, DataType::Date64, DataType::Date64, ]), TypeSignature::Exact(vec![ DataType::Utf8, DataType::Int64, DataType::Int64, ]), ], Volatility::Immutable, ); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } "dsql_totimestamp" => { let first_datatypes = vec![ DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64, DataType::UInt8, DataType::UInt16, DataType::UInt32, DataType::UInt64, DataType::Utf8, ]; let sig = generate_signatures(vec![first_datatypes, vec![DataType::Utf8]]); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Date64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } "mod" => { let sig = generate_signatures(vec![numeric_datatypes.clone(), numeric_datatypes]); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } "cbrt" | "cot" | 
"degrees" | "radians" | "sign" | "truncate" => { let sig = generate_signatures(vec![numeric_datatypes]); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } "rand" => { let sig = Signature::one_of( vec![ TypeSignature::Exact(vec![]), TypeSignature::Exact(vec![DataType::Int64]), ], Volatility::Immutable, ); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } "rand_integer" => { let sig = Signature::one_of( vec![ TypeSignature::Exact(vec![DataType::Int64]), TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]), ], Volatility::Immutable, ); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } "extract_date" => { let sig = Signature::one_of( vec![ TypeSignature::Exact(vec![DataType::Utf8, DataType::Date64]), TypeSignature::Exact(vec![ DataType::Utf8, DataType::Timestamp(TimeUnit::Nanosecond, None), ]), ], Volatility::Immutable, ); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Date64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } _ => (), } // Loop through all of the user defined functions for schema in self.schemas.values() { for (fun_name, func_mutex) in &schema.functions { if fun_name.eq(name) { let function = func_mutex.lock().unwrap(); if function.aggregation.eq(&true) { return None; } let sig = { Signature::one_of( function .return_types .keys() .map(|v| TypeSignature::Exact(v.to_vec())) .collect(), Volatility::Immutable, ) }; let function = function.clone(); let rtf: ReturnTypeFunction = Arc::new(move |input_types| { match function.return_types.get(&input_types.to_vec()) { Some(return_type) => Ok(Arc::new(return_type.clone())), None => Err(DataFusionError::Plan(format!( "UDF signature not found for input types {input_types:?}" ))), } }); return 
Some(Arc::new(ScalarUDF::new( fun_name.as_str(), &sig, &rtf, &fun, ))); } } } None } fn get_aggregate_meta(&self, name: &str) -> Option> { let acc: AccumulatorFactoryFunction = Arc::new(|_return_type| Err(DataFusionError::NotImplemented("".to_string()))); let st: StateTypeFunction = Arc::new(|_| Err(DataFusionError::NotImplemented("".to_string()))); let numeric_datatypes = vec![ DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64, DataType::UInt8, DataType::UInt16, DataType::UInt32, DataType::UInt64, DataType::Float16, DataType::Float32, DataType::Float64, ]; match name { "every" => { // let sig = generate_signatures(vec![DataType::Boolean]); let sig = Signature::exact(vec![DataType::Boolean], Volatility::Immutable); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Boolean))); return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st))); } "bit_and" | "bit_or" => { let sig = generate_signatures(vec![numeric_datatypes]); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64))); return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st))); } "single_value" => { let sig = generate_signatures(vec![numeric_datatypes]); let rtf: ReturnTypeFunction = Arc::new(|input_types| Ok(Arc::new(input_types[0].clone()))); return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st))); } "regr_count" => { let sig = generate_signatures(vec![numeric_datatypes.clone(), numeric_datatypes]); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64))); return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st))); } "regr_syy" | "regr_sxx" => { let sig = generate_signatures(vec![numeric_datatypes.clone(), numeric_datatypes]); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64))); return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st))); } _ => (), } // Loop through all of the user defined functions for schema in self.schemas.values() { for (fun_name, func_mutex) 
in &schema.functions { if fun_name.eq(name) { let function = func_mutex.lock().unwrap(); if function.aggregation.eq(&false) { return None; } let sig = { Signature::one_of( function .return_types .keys() .map(|v| TypeSignature::Exact(v.to_vec())) .collect(), Volatility::Immutable, ) }; let function = function.clone(); let rtf: ReturnTypeFunction = Arc::new(move |input_types| { match function.return_types.get(&input_types.to_vec()) { Some(return_type) => Ok(Arc::new(return_type.clone())), None => Err(DataFusionError::Plan(format!( "UDAF signature not found for input types {input_types:?}" ))), } }); return Some(Arc::new(AggregateUDF::new(fun_name, &sig, &rtf, &acc, &st))); } } } None } fn get_variable_type(&self, _: &[String]) -> Option { unimplemented!("RUST: get_variable_type is not yet implemented for DaskSQLContext") } fn options(&self) -> &ConfigOptions { &self.options } fn get_window_meta( &self, _name: &str, ) -> Option> { unimplemented!("RUST: get_window_meta is not yet implemented for DaskSQLContext") } } #[pymethods] impl DaskSQLContext { #[new] pub fn new( default_catalog_name: &str, default_schema_name: &str, optimizer_config: DaskSQLOptimizerConfig, ) -> Self { Self { current_catalog: default_catalog_name.to_owned(), current_schema: default_schema_name.to_owned(), schemas: HashMap::new(), options: ConfigOptions::new(), optimizer_config, } } pub fn set_optimizer_config(&mut self, config: DaskSQLOptimizerConfig) -> PyResult<()> { self.optimizer_config = config; Ok(()) } /// Change the current schema pub fn use_schema(&mut self, schema_name: &str) -> PyResult<()> { if self.schemas.contains_key(schema_name) { self.current_schema = schema_name.to_owned(); Ok(()) } else { Err(py_runtime_err(format!( "Schema: {schema_name} not found in DaskSQLContext" ))) } } /// Register a Schema with the current DaskSQLContext pub fn register_schema( &mut self, schema_name: String, schema: schema::DaskSchema, ) -> PyResult { self.schemas.insert(schema_name, schema); Ok(true) 
} /// Register a DaskTable instance under the specified schema in the current DaskSQLContext pub fn register_table( &mut self, schema_name: String, table: table::DaskTable, ) -> PyResult { match self.schemas.get_mut(&schema_name) { Some(schema) => { schema.add_table(table); Ok(true) } None => Err(py_runtime_err(format!( "Schema: {schema_name} not found in DaskSQLContext" ))), } } /// Parses a SQL string into an AST presented as a Vec of Statements pub fn parse_sql(&self, sql: &str) -> PyResult> { debug!("parse_sql - '{}'", sql); let dd: DaskDialect = DaskDialect {}; match DaskParser::parse_sql_with_dialect(sql, &dd) { Ok(k) => { let mut statements: Vec = Vec::new(); for statement in k { statements.push(statement.into()); } Ok(statements) } Err(e) => Err(py_parsing_exp(e)), } } /// Creates a non-optimized Relational Algebra LogicalPlan from an AST Statement pub fn logical_relational_algebra( &self, statement: statement::PyStatement, ) -> PyResult { self._logical_relational_algebra(statement.statement) .map(|e| PyLogicalPlan { original_plan: e, current_node: None, }) .map_err(py_parsing_exp) } pub fn run_preoptimizer( &self, existing_plan: logical::PyLogicalPlan, ) -> PyResult { if let Some(plan) = datetime_coercion(&existing_plan.original_plan) { Ok(plan.into()) } else { Ok(existing_plan) } } /// Accepts an existing relational plan, `LogicalPlan`, and optimizes it /// by applying a set of `optimizer` trait implementations against the /// `LogicalPlan` pub fn optimize_relational_algebra( &self, existing_plan: logical::PyLogicalPlan, ) -> PyResult { // Certain queries cannot be optimized. Ex: `EXPLAIN SELECT * FROM test` simply return those plans as is let mut visitor = OptimizablePlanVisitor {}; match existing_plan.original_plan.visit(&mut visitor) { Ok(valid) => { match valid { VisitRecursion::Stop => { // This LogicalPlan does not support Optimization. Return original warn!("This LogicalPlan does not support Optimization. 
Returning original"); Ok(existing_plan) } _ => { let optimized_plan = optimizer::DaskSqlOptimizer::new( self.optimizer_config.fact_dimension_ratio, self.optimizer_config.max_fact_tables, self.optimizer_config.preserve_user_order, self.optimizer_config.filter_selectivity, ) .optimize(existing_plan.original_plan) .map(|k| PyLogicalPlan { original_plan: k, current_node: None, }) .map_err(py_optimization_exp); if let Ok(optimized_plan) = optimized_plan { if self.optimizer_config.dynamic_partition_pruning { optimizer::DaskSqlOptimizer::dynamic_partition_pruner( self.optimizer_config.fact_dimension_ratio, ) .optimize_once(optimized_plan.original_plan) .map(|k| PyLogicalPlan { original_plan: k, current_node: None, }) .map_err(py_optimization_exp) } else { Ok(optimized_plan) } } else { optimized_plan } } } } Err(e) => Err(py_optimization_exp(e)), } } } /// non-Python methods impl DaskSQLContext { /// Creates a non-optimized Relational Algebra LogicalPlan from an AST Statement pub fn _logical_relational_algebra( &self, dask_statement: DaskStatement, ) -> Result { match dask_statement { DaskStatement::Statement(statement) => { let planner = SqlToRel::new(self); planner.statement_to_plan(DFStatement::Statement(statement)) } DaskStatement::CreateModel(create_model) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(CreateModelPlanNode { schema_name: create_model.schema_name, model_name: create_model.model_name, input: self._logical_relational_algebra(create_model.select)?, if_not_exists: create_model.if_not_exists, or_replace: create_model.or_replace, with_options: create_model.with_options, }), })), DaskStatement::CreateExperiment(create_experiment) => { Ok(LogicalPlan::Extension(Extension { node: Arc::new(CreateExperimentPlanNode { schema_name: create_experiment.schema_name, experiment_name: create_experiment.experiment_name, input: self._logical_relational_algebra(create_experiment.select)?, if_not_exists: create_experiment.if_not_exists, or_replace: 
create_experiment.or_replace, with_options: create_experiment.with_options, }), })) } DaskStatement::PredictModel(predict_model) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(PredictModelPlanNode { schema_name: predict_model.schema_name, model_name: predict_model.model_name, input: self._logical_relational_algebra(predict_model.select)?, }), })), DaskStatement::DescribeModel(describe_model) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(DescribeModelPlanNode { schema: Arc::new(DFSchema::empty()), schema_name: describe_model.schema_name, model_name: describe_model.model_name, }), })), DaskStatement::CreateCatalogSchema(create_schema) => { Ok(LogicalPlan::Extension(Extension { node: Arc::new(CreateCatalogSchemaPlanNode { schema: Arc::new(DFSchema::empty()), schema_name: create_schema.schema_name, if_not_exists: create_schema.if_not_exists, or_replace: create_schema.or_replace, }), })) } DaskStatement::CreateTable(create_table) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(CreateTablePlanNode { schema: Arc::new(DFSchema::empty()), schema_name: create_table.schema_name, table_name: create_table.table_name, if_not_exists: create_table.if_not_exists, or_replace: create_table.or_replace, with_options: create_table.with_options, }), })), DaskStatement::ExportModel(export_model) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(ExportModelPlanNode { schema: Arc::new(DFSchema::empty()), schema_name: export_model.schema_name, model_name: export_model.model_name, with_options: export_model.with_options, }), })), DaskStatement::DropModel(drop_model) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(DropModelPlanNode { schema_name: drop_model.schema_name, model_name: drop_model.model_name, if_exists: drop_model.if_exists, schema: Arc::new(DFSchema::empty()), }), })), DaskStatement::ShowSchemas(show_schemas) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(ShowSchemasPlanNode { schema: Arc::new(DFSchema::empty()), catalog_name: 
show_schemas.catalog_name, like: show_schemas.like, }), })), DaskStatement::ShowTables(show_tables) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(ShowTablesPlanNode { schema: Arc::new(DFSchema::empty()), catalog_name: show_tables.catalog_name, schema_name: show_tables.schema_name, }), })), DaskStatement::ShowColumns(show_columns) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(ShowColumnsPlanNode { schema: Arc::new(DFSchema::empty()), table_name: show_columns.table_name, schema_name: show_columns.schema_name, }), })), DaskStatement::ShowModels(show_models) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(ShowModelsPlanNode { schema: Arc::new(DFSchema::empty()), schema_name: show_models.schema_name, }), })), DaskStatement::DropSchema(drop_schema) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(DropSchemaPlanNode { schema: Arc::new(DFSchema::empty()), schema_name: drop_schema.schema_name, if_exists: drop_schema.if_exists, }), })), DaskStatement::UseSchema(use_schema) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(UseSchemaPlanNode { schema: Arc::new(DFSchema::empty()), schema_name: use_schema.schema_name, }), })), DaskStatement::AnalyzeTable(analyze_table) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(AnalyzeTablePlanNode { schema: Arc::new(DFSchema::empty()), table_name: analyze_table.table_name, schema_name: analyze_table.schema_name, columns: analyze_table.columns, }), })), DaskStatement::AlterTable(alter_table) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(AlterTablePlanNode { schema: Arc::new(DFSchema::empty()), old_table_name: alter_table.old_table_name, new_table_name: alter_table.new_table_name, schema_name: alter_table.schema_name, if_exists: alter_table.if_exists, }), })), DaskStatement::AlterSchema(alter_schema) => Ok(LogicalPlan::Extension(Extension { node: Arc::new(AlterSchemaPlanNode { schema: Arc::new(DFSchema::empty()), old_schema_name: alter_schema.old_schema_name, new_schema_name: 
alter_schema.new_schema_name, }), })), } } } /// Visits each AST node to determine if the plan is valid for optimization or not pub struct OptimizablePlanVisitor; impl TreeNodeVisitor for OptimizablePlanVisitor { type N = LogicalPlan; fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { // If the plan contains an unsupported Node type we flag the plan as un-optimizable here match plan { LogicalPlan::Explain(..) => Ok(VisitRecursion::Stop), _ => Ok(VisitRecursion::Continue), } } fn post_visit(&mut self, _plan: &LogicalPlan) -> Result { Ok(VisitRecursion::Continue) } } fn generate_signatures(cartesian_setup: Vec>) -> Signature { let mut exact_vector = vec![]; let mut datatypes_iter = cartesian_setup.iter(); // First pass if let Some(first_iter) = datatypes_iter.next() { for datatype in first_iter { exact_vector.push(vec![datatype.clone()]); } } // Generate the Cartesian product for iter in datatypes_iter { let mut outer_temp = vec![]; for outer_datatype in exact_vector { for inner_datatype in iter { let mut inner_temp = outer_datatype.clone(); inner_temp.push(inner_datatype.clone()); outer_temp.push(inner_temp); } } exact_vector = outer_temp; } // Create vector of TypeSignatures let mut one_of_vector = vec![]; for vector in exact_vector.iter() { one_of_vector.push(TypeSignature::Exact(vector.clone())); } Signature::one_of(one_of_vector.clone(), Volatility::Immutable) } #[cfg(test)] mod test { use datafusion_python::{ datafusion::arrow::datatypes::DataType, datafusion_expr::{Signature, TypeSignature, Volatility}, }; use crate::sql::generate_signatures; #[test] fn test_generate_signatures() { let sig = generate_signatures(vec![ vec![DataType::Int64, DataType::Float64], vec![DataType::Utf8, DataType::Int64], ]); let expected = Signature::one_of( vec![ TypeSignature::Exact(vec![DataType::Int64, DataType::Utf8]), TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]), TypeSignature::Exact(vec![DataType::Float64, DataType::Utf8]), 
@pytest.fixture()
def df():
    """Return a 700-row frame with a grouped column and a seeded random column.

    Column ``a`` holds 100 values of 1.0, 200 of 2.0 and 400 of 3.0.
    Column ``b`` holds uniform random floats scaled by 10, drawn after
    seeding NumPy's global generator with 42.
    """
    np.random.seed(42)
    group_values = [1.0] * 100 + [2.0] * 200 + [3.0] * 400
    random_values = 10 * np.random.rand(700)
    return pd.DataFrame({"a": group_values, "b": random_values})
@pytest.fixture()
def datetime_table():
    """Return a frame with three datetime columns in different timezones.

    Each column holds six timestamps, eight hours apart, starting at
    2014-08-01 09:00: ``timezone`` is localized to Europe/Berlin,
    ``no_timezone`` is timezone-naive and ``utc_timezone`` is in UTC.
    """
    # An explicit Timedelta replaces the "8H" alias: the uppercase "H" hour
    # alias is deprecated in pandas 2.2 and emits a FutureWarning. The
    # resulting 8-hour frequency is the same.
    step = pd.Timedelta(hours=8)
    start = "2014-08-01 09:00"
    return pd.DataFrame(
        {
            "timezone": pd.date_range(
                start=start, freq=step, periods=6, tz="Europe/Berlin"
            ),
            "no_timezone": pd.date_range(start=start, freq=step, periods=6),
            "utc_timezone": pd.date_range(
                start=start, freq=step, periods=6, tz="UTC"
            ),
        }
    )
@pytest.fixture()
def c(
    df_simple,
    df_wide,
    df,
    department_table,
    user_table_1,
    user_table_2,
    long_table,
    user_table_inf,
    user_table_nan,
    string_table,
    datetime_table,
    timeseries,
    parquet_ddf,
    gpu_user_table_1,
    gpu_df,
    gpu_long_table,
    gpu_string_table,
    gpu_datetime_table,
    gpu_timeseries,
):
    """Yield a dask-sql ``Context`` with every available table fixture registered.

    Each fixture is registered under its own fixture name. Fixtures that
    evaluate to ``None`` are skipped, and anything that is not already a dask
    collection is converted to a dask dataframe with three partitions.
    """
    # Lazy import, otherwise the pytest framework has problems
    from dask_sql.context import Context

    tables = {
        "df_simple": df_simple,
        "df_wide": df_wide,
        "df": df,
        "department_table": department_table,
        "user_table_1": user_table_1,
        "user_table_2": user_table_2,
        "long_table": long_table,
        "user_table_inf": user_table_inf,
        "user_table_nan": user_table_nan,
        "string_table": string_table,
        "datetime_table": datetime_table,
        "timeseries": timeseries,
        "parquet_ddf": parquet_ddf,
        "gpu_user_table_1": gpu_user_table_1,
        "gpu_df": gpu_df,
        "gpu_long_table": gpu_long_table,
        "gpu_string_table": gpu_string_table,
        "gpu_datetime_table": gpu_datetime_table,
        "gpu_timeseries": gpu_timeseries,
    }

    context = Context()
    for table_name, table in tables.items():
        if table is None:
            continue
        # objects exposing ``npartitions`` are already dask collections
        already_dask = hasattr(table, "npartitions")
        dask_table = table if already_dask else dd.from_pandas(table, npartitions=3)
        context.create_table(table_name, dask_table)

    yield context
@pytest.fixture()
def gpu_client(request):
    """Yield a distributed ``Client``, backed by a CUDA cluster unless disabled.

    The fixture can be used directly (a ``LocalCUDACluster`` is started) or
    parametrized: a falsy ``request.param`` yields a plain ``Client()``
    instead. If the GPU stack could not be imported the test is skipped
    rather than failing with ``'NoneType' object is not callable``.
    """
    # allow gpu_client to be used directly as a fixture or parametrized
    wants_cuda_cluster = getattr(request, "param", True)

    if not wants_cuda_cluster:
        with Client() as client:
            yield client
        return

    if LocalCUDACluster is None:
        # the module-level import guard sets this to None on ImportError
        pytest.skip("GPU dependencies (cudf / dask_cudf / dask_cuda) are not importable")

    with LocalCUDACluster(protocol="tcp") as cluster:
        with Client(cluster) as client:
            yield client
def test_analyze(c, df):
    """Check ``ANALYZE TABLE`` output against statistics assembled by hand."""
    all_columns_stats = c.sql("ANALYZE TABLE df COMPUTE STATISTICS FOR ALL COLUMNS")

    # Rebuild the expected frame manually: the describe() rows first, then one
    # row holding the lower-cased SQL type names and one row holding the
    # column names.
    column_names = df.columns
    described = c.sql("SELECT * FROM df").describe()
    data_type_row = pd.DataFrame(
        {
            name: str(python_to_sql_type(df[name].dtype)).lower()
            for name in column_names
        },
        index=["data_type"],
    )
    col_name_row = pd.DataFrame(
        {name: name for name in column_names},
        index=["col_name"],
    )
    expected_df = dd.concat([described, data_type_row, col_name_row])

    assert_eq(all_columns_stats, expected_df)

    single_column_stats = c.sql("ANALYZE TABLE df COMPUTE STATISTICS FOR COLUMNS a")
    assert_eq(single_column_stats, expected_df[["a"]])
def _feed_cli_with_input(
    text,
    editing_mode=None,
    clipboard=None,
    history=None,
    multiline=False,
    check_line_ending=True,
    key_bindings=None,
):
    """
    Create a Prompt, feed it with the given user input and return the outcome.

    Returns a ``(document, application)`` tuple: the document of the session's
    default buffer once the prompt has finished, and the session's app.
    """
    from contextlib import closing

    # If the given text doesn't end with a newline, the interface won't finish.
    if check_line_ending:
        assert text.endswith("\r")

    # Same compatibility switch as mock_prompt_input: depending on the
    # prompt-toolkit version, create_pipe_input() is either a context manager
    # or a plain pipe object that has to be closed explicitly.
    pipe_input = create_pipe_input()
    if not PIPE_INPUT_CONTEXT_MANAGER:
        pipe_input = closing(pipe_input)

    with pipe_input as inp:
        inp.send_text(text)
        session = PromptSession(
            input=inp,
            output=DummyOutput(),
            editing_mode=editing_mode,
            history=history,
            multiline=multiline,
            clipboard=clipboard,
            key_bindings=key_bindings,
        )

        # the prompt's return value is not needed, only the resulting state
        session.prompt()
        return session.default_buffer.document, session.app
def test_non_meta_commands(c, client, capsys):
    """Unknown backslash commands print a hint; plain SQL is not handled."""
    _meta_commands("\\x", context=c, client=client)
    printed = capsys.readouterr().out
    expected_hint = (
        "The meta command \\x not available, please use commands from below list"
    )
    assert expected_hint in printed

    handled = _meta_commands("Select 42 as answer", context=c, client=client)
    # drain the capture buffer, as the original test does
    capsys.readouterr()
    assert handled is False
def eq_sqlite(sql, **dfs):
    """Run ``sql`` through dask-sql and SQLite and assert the results match.

    Every keyword argument is a pandas frame that is registered under its
    keyword name both in a fresh dask-sql ``Context`` and in an in-memory
    SQLite database. Before comparing, both results go through
    ``convert_dtypes`` and ``convert_nullable_columns``; the SQLite result is
    parsed to datetimes where dask-sql returned ``datetime64[ns]`` and then
    cast to the dask result's dtypes.
    """
    from contextlib import closing

    c = Context()

    # sqlite3 connections are not closed by garbage collection in a timely
    # manner; close explicitly so repeated calls do not pile up connections.
    with closing(sqlite3.connect(":memory:")) as engine:
        for name, df in dfs.items():
            c.create_table(name, df)
            df.to_sql(name, engine, index=False)

        dask_result = c.sql(sql).compute().convert_dtypes()
        sqlite_result = pd.read_sql(sql, engine).convert_dtypes()

    convert_nullable_columns(dask_result)
    convert_nullable_columns(sqlite_result)

    datetime_cols = dask_result.select_dtypes(
        include=["datetime64[ns]"]
    ).columns.tolist()
    for col in datetime_cols:
        sqlite_result[col] = pd.to_datetime(sqlite_result[col])

    sqlite_result = sqlite_result.astype(dask_result.dtypes)

    assert_eq(dask_result, sqlite_result, check_dtype=False, check_index=False)


def _draw_from_pool(pool, size):
    """Return an array of ``size`` items drawn (with repetition) from ``pool``."""
    picks = np.random.randint(len(pool), size=size)
    return np.array([pool[i] for i in picks])


def make_rand_df(size: int, **kwargs):
    """Build a reproducible random pandas frame with ``size`` rows.

    Each keyword argument describes one column: the keyword is the column
    name and the value is either a type or a ``(type, null_count)`` tuple.
    Supported types are ``int``, ``bool``, ``float``, ``str``,
    ``pd.StringDtype`` and ``datetime``. When ``null_count`` is positive that
    many randomly chosen rows of the column are set to ``None``.

    The global NumPy RNG is re-seeded with 0 on every call, so identical
    arguments always give identical frames.

    Raises:
        NotImplementedError: if a column asks for an unsupported type.
    """
    np.random.seed(0)
    data = {}
    for column, spec in kwargs.items():
        if not isinstance(spec, tuple):
            spec = (spec, 0.0)
        dtype, null_count = spec[0], spec[1]

        if dtype is int:
            values = np.random.randint(10, size=size)
        elif dtype is bool:
            values = np.where(np.random.randint(2, size=size), True, False)
        elif dtype is float:
            values = np.random.rand(size)
        elif dtype is str or dtype is pd.StringDtype:
            values = _draw_from_pool([f"ssssss{x}" for x in range(10)], size)
            if dtype is pd.StringDtype:
                values = pd.array(values, dtype="string")
        elif dtype is datetime:
            days = [datetime(2020, 1, 1) + timedelta(days=x) for x in range(10)]
            values = _draw_from_pool(days, size)
        else:
            raise NotImplementedError(
                f"make_rand_df cannot generate column {column!r} of type {dtype!r}"
            )

        series = pd.Series(values)
        if null_count > 0:
            null_rows = np.random.choice(size, null_count, replace=False).tolist()
            series[null_rows] = None
        data[column] = series
    return pd.DataFrame(data)
def test_in_between():
    """Compare IN / BETWEEN filters and their negations with SQLite."""
    df = make_rand_df(10, a=(int, 3), b=(str, 3))

    queries = (
        "SELECT * FROM a WHERE a IN (2,4,6)",
        "SELECT * FROM a WHERE a BETWEEN 2 AND 4+1",
        "SELECT * FROM a WHERE a NOT IN (2,4,6) AND a IS NOT NULL",
        "SELECT * FROM a WHERE a NOT BETWEEN 2 AND 4+1 AND a IS NOT NULL",
        (
            "SELECT * FROM a WHERE SUBSTR(b,1,2) IN ('ss','s') "
            "AND a NOT BETWEEN 3 AND 5 and a IS NOT NULL"
        ),
    )
    # same frame, same order of queries as before; each one is checked on its own
    for query in queries:
        eq_sqlite(query, a=df)
def test_single_agg_count_no_group_by():
    """COUNT and COUNT(DISTINCT) of one column without GROUP BY match SQLite."""
    # column name -> (type, number of nulls); dict order fixes the RNG order
    column_specs = {
        "a": (int, 50),
        "b": (str, 50),
        "c": (int, 30),
        "d": (str, 40),
        "e": (float, 40),
    }
    table = make_rand_df(100, **column_specs)

    query = """ SELECT COUNT(a) AS c_a, COUNT(DISTINCT a) AS cd_a FROM a """
    eq_sqlite(query, a=table)
def test_agg_sum_avg_no_group_by():
    """SUM / AVG without GROUP BY match SQLite on a tiny and a random frame."""
    single_row = pd.DataFrame({"a": [float("2.3")]})
    simple_query = """ SELECT SUM(a) AS sum_a, AVG(a) AS avg_a FROM a """
    eq_sqlite(simple_query, a=single_row)

    # column name -> (type, number of nulls); dict order fixes the RNG order
    column_specs = {
        "a": (int, 50),
        "b": (str, 50),
        "c": (int, 30),
        "d": (str, 40),
        "e": (float, 40),
    }
    table = make_rand_df(100, **column_specs)

    mixed_query = (
        " SELECT SUM(a) AS sum_a, AVG(a) AS avg_a, "
        "SUM(c) AS sum_c, AVG(c) AS avg_c, "
        "SUM(e) AS sum_e, AVG(e) AS avg_e, "
        "SUM(a)+AVG(e) AS mix_1, SUM(a+e) AS mix_2 FROM a "
    )
    eq_sqlite(mixed_query, a=table)
def test_agg_min_max():
    """Grouped MIN / MAX over numeric, string and datetime columns match SQLite."""
    # column name -> (type, number of nulls); dict order fixes the RNG order
    column_specs = {
        "a": (int, 50),
        "b": (str, 50),
        "c": (int, 30),
        "d": (str, 40),
        "e": (float, 40),
        "f": (pd.StringDtype, 40),
        "g": (datetime, 40),
    }
    table = make_rand_df(100, **column_specs)

    query = (
        " SELECT a, b, a+1 AS c, "
        "MIN(c) AS min_c, MAX(c) AS max_c, "
        "MIN(d) AS min_d, MAX(d) AS max_d, "
        "MIN(e) AS min_e, MAX(e) AS max_e, "
        "MIN(f) AS min_f, MAX(f) AS max_f, "
        "MIN(g) AS min_g, MAX(g) AS max_g, "
        "MIN(a+e) AS mix_1, MIN(a)+MIN(e) AS mix_2 "
        "FROM a GROUP BY a, b ORDER BY a NULLS FIRST, b NULLS FIRST "
    )
    eq_sqlite(query, a=table)
@pytest.mark.xfail(
    reason="Need to implement rank/lead/lag window functions, see https://github.com/dask-contrib/dask-sql/issues/878"
)
def test_window_ranks():
    """RANK / DENSE_RANK / PERCENT_RANK windows (expected to fail for now)."""
    table = make_rand_df(100, a=int, b=(float, 50), c=(str, 50))

    query = (
        " SELECT *, "
        "RANK() OVER (PARTITION BY a ORDER BY b DESC NULLS FIRST, c) AS a1, "
        "DENSE_RANK() OVER (ORDER BY a ASC, b DESC NULLS LAST, c DESC) AS a2, "
        "PERCENT_RANK() OVER (ORDER BY a ASC, b ASC NULLS LAST, c) AS a4 "
        "FROM a "
    )
    eq_sqlite(query, a=table)
(PARTITION BY c ORDER BY a) AS a3, LEAD(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS a5, LAG(b,1) OVER (PARTITION BY c ORDER BY a) AS b3, LAG(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS b5 FROM a """, a=a, ) def test_window_sum_avg(): a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) for func in ["SUM", "AVG"]: eq_sqlite( f""" SELECT a,b, {func}(b) OVER () AS a1, {func}(b) OVER (PARTITION BY c) AS a2, {func}(b+a) OVER (PARTITION BY c,b) AS a3, {func}(b+a) OVER (PARTITION BY b ORDER BY a NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, {func}(b+a) OVER (PARTITION BY b ORDER BY a NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS a6 FROM a ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST """, a=a, ) # irregular windows eq_sqlite( f""" SELECT a,b, {func}(b) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6, {func}(b) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) AS a7, {func}(b) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS a8 FROM a ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST """, a=a, ) def test_window_sum_avg_partition_by(): a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) for func in ["SUM", "AVG"]: eq_sqlite( f""" SELECT a,b, {func}(b+a) OVER (PARTITION BY c,b) AS a3, {func}(b+a) OVER (PARTITION BY b ORDER BY a NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, {func}(b+a) OVER (PARTITION BY b ORDER BY a NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS a6 FROM a ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST """, a=a, ) # irregular windows eq_sqlite( f""" 
SELECT a,b, {func}(b) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6, {func}(b) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) AS a7, {func}(b) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS a8 FROM a ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST """, a=a, ) def test_window_min_max(): for func in ["MIN", "MAX"]: a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) eq_sqlite( f""" SELECT a,b, {func}(b) OVER () AS a1, {func}(b) OVER (PARTITION BY c) AS a2, {func}(b+a) OVER (PARTITION BY c,b) AS a3, {func}(b+a) OVER (PARTITION BY b ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, {func}(b+a) OVER (PARTITION BY b ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS a6 FROM a ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST """, a=a, ) # irregular windows eq_sqlite( f""" SELECT a,b, {func}(b) OVER (ORDER BY a DESC ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6, {func}(b) OVER (ORDER BY a DESC ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) AS a7, {func}(b) OVER (ORDER BY a DESC ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS a8 FROM a ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST """, a=a, ) b = make_rand_df(10, a=float, b=(int, 0), c=(str, 0)) eq_sqlite( f""" SELECT a,b, {func}(b) OVER (PARTITION BY b ORDER BY a DESC ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6 FROM a ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST """, a=b, ) def test_window_min_max_partition_by(): for func in ["MIN", "MAX"]: a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) eq_sqlite( f""" SELECT a,b, {func}(b) OVER (PARTITION BY c) AS a2, {func}(b+a) OVER (PARTITION BY c,b) AS a3, {func}(b+a) OVER (PARTITION BY b ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, 
{func}(b+a) OVER (PARTITION BY b ORDER BY a DESC ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, {func}(b+a) OVER (PARTITION BY b ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS a6 FROM a ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST """, a=a, ) b = make_rand_df(10, a=float, b=(int, 0), c=(str, 0)) eq_sqlite( f""" SELECT a,b, {func}(b) OVER (PARTITION BY b ORDER BY a DESC ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6 FROM a ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST """, a=b, ) # TODO: investigate source of window count deadlocks @skipif_dask_expr_enabled("Deadlocks with query planning enabled") def test_window_count(): for func in ["COUNT"]: a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50)) eq_sqlite( f""" SELECT a,b, {func}(b) OVER () AS a1, {func}(b) OVER (PARTITION BY c) AS a2, {func}(b+a) OVER (PARTITION BY c,b) AS a3, {func}(b+a) OVER (PARTITION BY b ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4, {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5, {func}(b+a) OVER (PARTITION BY b ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS a6, {func}(c) OVER () AS b1, {func}(c) OVER (PARTITION BY c) AS b2, {func}(c) OVER (PARTITION BY c,b) AS b3, {func}(c) OVER (PARTITION BY b ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS b4, {func}(c) OVER (PARTITION BY b ORDER BY a DESC ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS b5, {func}(c) OVER (PARTITION BY b ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS b6 FROM a ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST """, a=a, ) # irregular windows eq_sqlite( f""" SELECT a,b, {func}(b) OVER (ORDER BY a DESC ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS a6, {func}(b) OVER (PARTITION BY c ORDER BY a DESC ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS a9, {func}(c) OVER (ORDER BY a DESC ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS b6, {func}(c) OVER 
def test_nested_query():
    """Filtering on a ROW_NUMBER computed in a subquery matches SQLite."""
    table = make_rand_df(100, a=float, b=(int, 50), c=(str, 50))

    query = (
        " SELECT * FROM ( "
        "SELECT *, ROW_NUMBER() OVER "
        "(PARTITION BY c ORDER BY b NULLS FIRST, a ASC NULLS LAST) AS r "
        "FROM a) WHERE r=1 "
        "ORDER BY a NULLS LAST, b NULLS LAST, c NULLS LAST "
    )
    eq_sqlite(query, a=table)
def test_with():
    """Chained common table expressions combined via UNION match SQLite."""
    left = make_rand_df(30, a=(int, 10), b=(str, 10))
    right = make_rand_df(80, ax=(int, 10), bx=(str, 10))

    query = (
        " WITH aa AS ( SELECT a AS aa, b AS bb FROM a ), "
        "c AS ( SELECT aa-1 AS aa, bb FROM aa ) "
        "SELECT * FROM c UNION SELECT * FROM b "
        "ORDER BY aa NULLS FIRST, bb NULLS FIRST "
    )
    eq_sqlite(query, a=left, b=right)
reason="https://github.com/dask-contrib/dask-sql/issues/1092" ), ), ], ) def test_query_case_sensitivity(case_sensitive): c = Context() df = pd.DataFrame({"id": [0, 1], "VAL": [1, 2]}) c.create_table("test", df) q1 = "select ID from test" q2 = "select val from test" q3 = "select Id, VAl from test" with dask.config.set({"sql.identifier.case_sensitive": case_sensitive}): if case_sensitive: with pytest.raises(ParsingException): c.sql(q1) with pytest.raises(ParsingException): c.sql(q2) with pytest.raises(ParsingException): c.sql(q3) result = c.sql("SELECT VAL from test") assert_eq(result, df[["VAL"]]) else: df.columns = df.columns.str.lower() result = c.sql(q1) assert_eq(result, df[["id"]]) result = c.sql(q2) assert_eq(result, df[["val"]]) result = c.sql(q3) assert_eq(result, df[["id", "val"]]) def test_column_name_starting_with_number(): c = Context() df = pd.DataFrame({"a": range(10), "1b": range(10)}) c.create_table("df", df) result = c.sql( """ SELECT "1b" AS x FROM df """ ) expected = pd.DataFrame({"x": range(10)}) assert_eq(result, expected) result = c.sql( """ SELECT (CASE WHEN "1b"=1 THEN 0 END) AS x FROM df """ ) expected = pd.DataFrame( {"x": [None, 0, None, None, None, None, None, None, None, None]} ) assert_eq(result, expected) ================================================ FILE: tests/integration/test_complex.py ================================================ from dask.datasets import timeseries def test_complex_query(c): df = timeseries(freq="1d").persist() c.create_table("timeseries", df) result = c.sql( """ SELECT lhs.name, lhs.id, lhs.x FROM timeseries AS lhs JOIN ( SELECT name AS max_name, MAX(x) AS max_x FROM timeseries GROUP BY name ) AS rhs ON lhs.name = rhs.max_name AND lhs.x = rhs.max_x """ ).compute() assert len(result) > 0 ================================================ FILE: tests/integration/test_create.py ================================================ import dask.dataframe as dd import pandas as pd import pytest import dask_sql from 
tests.utils import assert_eq @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_create_from_csv(c, df, temporary_data_file, gpu): df.to_csv(temporary_data_file, index=False) c.sql( f""" CREATE TABLE new_table WITH ( location = '{temporary_data_file}', format = 'csv', gpu = {gpu} ) """ ) result_df = c.sql( """ SELECT * FROM new_table """ ) assert_eq(result_df, df) @pytest.mark.parametrize( "gpu", [ False, pytest.param(True, marks=pytest.mark.gpu), ], ) def test_cluster_memory(client, c, df, gpu): client.publish_dataset(df=dd.from_pandas(df, npartitions=1)) c.sql( f""" CREATE TABLE new_table WITH ( location = 'df', format = 'memory', gpu = {gpu} ) """ ) return_df = c.sql( """ SELECT * FROM new_table """ ) assert_eq(df, return_df) client.unpublish_dataset("df") @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_create_from_csv_persist(c, df, temporary_data_file, gpu): df.to_csv(temporary_data_file, index=False) c.sql( f""" CREATE TABLE new_table WITH ( location = '{temporary_data_file}', format = 'csv', persist = True, gpu = {gpu} ) """ ) return_df = c.sql( """ SELECT * FROM new_table """ ) assert_eq(df, return_df) def test_wrong_create(c): with pytest.raises(AttributeError): c.sql( """ CREATE TABLE new_table WITH ( format = 'csv' ) """ ) with pytest.raises(AttributeError): c.sql( """ CREATE TABLE new_table WITH ( format = 'strange', location = 'some/path' ) """ ) def test_create_from_query(c, df): with pytest.raises(RuntimeError): c.sql( """ CREATE OR REPLACE TABLE other.new_table AS ( SELECT * FROM df ) """ ) c.sql( """ CREATE OR REPLACE TABLE new_table AS ( SELECT * FROM df ) """ ) return_df = c.sql( """ SELECT * FROM new_table """ ) assert_eq(df, return_df) with pytest.raises(RuntimeError): c.sql( """ CREATE OR REPLACE VIEW other.new_table AS ( SELECT * FROM df ) """ ) c.sql( """ CREATE OR REPLACE VIEW new_table AS ( SELECT * FROM df ) """ ) return_df = c.sql( """ SELECT * FROM 
new_table """ ) assert_eq(df, return_df) @pytest.mark.parametrize( "gpu", [ False, pytest.param( True, marks=pytest.mark.gpu, ), ], ) def test_view_table_persist(c, temporary_data_file, df, gpu): df.to_csv(temporary_data_file, index=False) c.sql( f""" CREATE TABLE new_table WITH ( location = '{temporary_data_file}', format = 'csv', gpu = {gpu} ) """ ) # Views should change, when the original data changes # Tables should not change, when the original data changes c.sql( """ CREATE VIEW count_view AS ( SELECT COUNT(*) AS c FROM new_table ) """ ) c.sql( """ CREATE TABLE count_table AS ( SELECT COUNT(*) AS c FROM new_table ) """ ) from_view = c.sql("SELECT c FROM count_view") from_table = c.sql("SELECT c FROM count_table") assert_eq(from_view, pd.DataFrame({"c": [700]})) assert_eq(from_table, pd.DataFrame({"c": [700]})) df.iloc[:10].to_csv(temporary_data_file, index=False) from_view = c.sql("SELECT c FROM count_view") from_table = c.sql("SELECT c FROM count_table") assert_eq(from_view, pd.DataFrame({"c": [10]})) assert_eq(from_table, pd.DataFrame({"c": [700]})) def test_replace_and_error(c, temporary_data_file, df): c.sql( """ CREATE TABLE new_table AS ( SELECT 1 AS a ) """ ) assert_eq( c.sql("SELECT a FROM new_table"), pd.DataFrame({"a": [1]}), check_dtype=False, ) with pytest.raises(RuntimeError): c.sql( """ CREATE TABLE new_table AS ( SELECT 1 ) """ ) c.sql( """ CREATE TABLE IF NOT EXISTS new_table AS ( SELECT 2 AS a ) """ ) assert_eq( c.sql("SELECT a FROM new_table"), pd.DataFrame({"a": [1]}), check_dtype=False, ) c.sql( """ CREATE OR REPLACE TABLE new_table AS ( SELECT 2 AS a ) """ ) assert_eq( c.sql("SELECT a FROM new_table"), pd.DataFrame({"a": [2]}), check_dtype=False, ) c.sql("DROP TABLE new_table") with pytest.raises(dask_sql.utils.ParsingException): c.sql("SELECT a FROM new_table") c.sql( """ CREATE TABLE IF NOT EXISTS new_table AS ( SELECT 3 AS a ) """ ) assert_eq( c.sql("SELECT a FROM new_table"), pd.DataFrame({"a": [3]}), check_dtype=False, ) 
df.to_csv(temporary_data_file, index=False) with pytest.raises(RuntimeError): c.sql( f""" CREATE TABLE new_table WITH ( location = '{temporary_data_file}', format = 'csv' ) """ ) c.sql( f""" CREATE TABLE IF NOT EXISTS new_table WITH ( location = '{temporary_data_file}', format = 'csv' ) """ ) assert_eq( c.sql("SELECT a FROM new_table"), pd.DataFrame({"a": [3]}), check_dtype=False, ) c.sql( f""" CREATE OR REPLACE TABLE new_table WITH ( location = '{temporary_data_file}', format = 'csv' ) """ ) result_df = c.sql("SELECT * FROM new_table") assert_eq(result_df, df) def test_drop(c): with pytest.raises(RuntimeError): c.sql("DROP TABLE new_table") c.sql("DROP TABLE IF EXISTS new_table") c.sql( """ CREATE TABLE new_table AS ( SELECT 1 AS a ) """ ) with pytest.raises(RuntimeError): c.sql("DROP TABLE other.new_table") c.sql("DROP TABLE IF EXISTS new_table") with pytest.raises(dask_sql.utils.ParsingException): c.sql("SELECT a FROM new_table") def test_create_gpu_error(c, df, temporary_data_file): try: import cudf except ImportError: cudf = None if cudf is not None: pytest.skip("GPU-related import errors only need to be checked on CPU") with pytest.raises(ModuleNotFoundError): c.create_table("new_table", df, gpu=True) with pytest.raises(ModuleNotFoundError): c.create_table("new_table", dd.from_pandas(df, npartitions=2), gpu=True) df.to_csv(temporary_data_file, index=False) with pytest.raises(ModuleNotFoundError): c.sql( f""" CREATE TABLE new_table WITH ( location = '{temporary_data_file}', format = 'csv', gpu = True ) """ ) ================================================ FILE: tests/integration/test_distributeby.py ================================================ import dask.dataframe as dd import pandas as pd import pytest @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_distribute_by(c, gpu): df = pd.DataFrame({"id": [0, 1, 2, 1, 2, 3], "val": [0, 1, 2, 1, 2, 3]}) ddf = dd.from_pandas(df, npartitions=2) c.create_table("test", 
ddf, gpu=gpu) partitioned_ddf = c.sql( """ SELECT id FROM test DISTRIBUTE BY id """ ) part_0_ids = partitioned_ddf.get_partition(0).compute().id.unique() part_1_ids = partitioned_ddf.get_partition(1).compute().id.unique() if gpu: part_0_ids = part_0_ids.to_pandas() part_1_ids = part_1_ids.to_pandas() assert bool(set(part_0_ids) & set(part_1_ids)) is False ================================================ FILE: tests/integration/test_explain.py ================================================ import dask.dataframe as dd import pandas as pd import pytest from dask_sql import Statistics @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_sql_query_explain(c, gpu): df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=1) c.create_table("df", df, gpu=gpu) sql_string = c.sql("EXPLAIN SELECT * FROM df") assert sql_string.startswith("Projection: df.a\n") sql_string = c.sql( "EXPLAIN SELECT MIN(a) AS a_min FROM other_df GROUP BY a", dataframes={"other_df": df}, gpu=gpu, ) assert sql_string.startswith("Projection: MIN(other_df.a) AS a_min\n") assert "Aggregate: groupBy=[[other_df.a]], aggr=[[MIN(other_df.a)]]" in sql_string @pytest.mark.xfail(reason="Need to add statistics to Rust optimizer") @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_statistics_explain(c, gpu): df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=1) c.create_table("df", df, statistics=Statistics(row_count=1337), gpu=gpu) sql_string = c.explain("SELECT * FROM df") assert sql_string.startswith( "DaskTableScan(table=[[root, df]]): rowcount = 1337.0, cumulative cost = {1337.0 rows, 1338.0 cpu, 0.0 io}, id = " ) ================================================ FILE: tests/integration/test_filter.py ================================================ import dask import dask.dataframe as dd import pandas as pd import pytest from dask.utils_test import hlg_layer from packaging.version import parse as 
parseVersion from tests.utils import assert_eq, skipif_dask_expr_enabled DASK_GT_2022_4_2 = parseVersion(dask.__version__) >= parseVersion("2022.4.2") def test_filter(c, df): return_df = c.sql("SELECT * FROM df WHERE a < 2") expected_df = df[df["a"] < 2] assert_eq(return_df, expected_df) def test_filter_scalar(c, df): return_df = c.sql("SELECT * FROM df WHERE True") expected_df = df assert_eq(return_df, expected_df) return_df = c.sql("SELECT * FROM df WHERE False") expected_df = df.head(0) assert_eq(return_df, expected_df, check_index_type=False) return_df = c.sql("SELECT * FROM df WHERE (1 = 1)") expected_df = df assert_eq(return_df, expected_df) return_df = c.sql("SELECT * FROM df WHERE (1 = 0)") expected_df = df.head(0) assert_eq(return_df, expected_df, check_index_type=False) def test_filter_complicated(c, df): return_df = c.sql("SELECT * FROM df WHERE a < 3 AND (b > 1 AND b < 3)") expected_df = df[((df["a"] < 3) & ((df["b"] > 1) & (df["b"] < 3)))] assert_eq( return_df, expected_df, ) def test_filter_with_nan(c): return_df = c.sql("SELECT * FROM user_table_nan WHERE c = 3") expected_df = pd.DataFrame({"c": [3]}, dtype="Int8") assert_eq( return_df, expected_df, ) def test_string_filter(c, string_table): return_df = c.sql("SELECT * FROM string_table WHERE a = 'a normal string'") assert_eq( return_df, string_table.head(1), ) # Condition needs to specifically check on `M` since this the literal `M` # was getting parsed as a datetime dtype return_df = c.sql("SELECT * from string_table WHERE a = 'M'") expected_df = string_table[string_table["a"] == "M"] assert_eq(return_df, expected_df) @pytest.mark.parametrize( "input_table", [ "datetime_table", pytest.param( "gpu_datetime_table", marks=(pytest.mark.gpu), ), ], ) def test_filter_cast_date(c, input_table, request): datetime_table = request.getfixturevalue(input_table) return_df = c.sql( f""" SELECT * FROM {input_table} WHERE CAST(timezone AS DATE) > DATE '2014-08-01' """ ) expected_df = datetime_table[ 
datetime_table["timezone"].dt.tz_localize(None).dt.floor("D").astype(" pd.Timestamp("2014-08-01") ] assert_eq(return_df, expected_df) @pytest.mark.parametrize( "input_table", [ "datetime_table", pytest.param( "gpu_datetime_table", marks=(pytest.mark.gpu), ), ], ) @pytest.mark.xfail( reason="Need support for non-UTC timezoned literals, see https://github.com/dask-contrib/dask-sql/issues/1193" ) def test_filter_cast_timestamp(c, input_table, request): datetime_table = request.getfixturevalue(input_table) return_df = c.sql( f""" SELECT * FROM {input_table} WHERE CAST(timezone AS TIMESTAMP) >= TIMESTAMP '2014-08-01 23:00:00+00' """ ) expected_df = datetime_table[ datetime_table["timezone"].astype("= pd.Timestamp("2014-08-01 23:00:00") ] assert_eq(return_df, expected_df) def test_filter_year(c): df = pd.DataFrame({"year": [2015, 2016], "month": [2, 3], "day": [4, 5]}) df["dt"] = pd.to_datetime(df) c.create_table("datetime_test", df) return_df = c.sql("select * from datetime_test where year(dt) < 2016") expected_df = df[df["year"] < 2016] assert_eq(expected_df, return_df) @pytest.mark.parametrize( "query,df_func,filters", [ ( "SELECT * FROM parquet_ddf WHERE b < 10", lambda x: x[x["b"] < 10], [[("b", "<", 10)]], ), ( "SELECT * FROM parquet_ddf WHERE a < 3 AND (b > 1 AND b < 5)", lambda x: x[(x["a"] < 3) & ((x["b"] > 1) & (x["b"] < 5))], [[("a", "<", 3), ("b", ">", 1), ("b", "<", 5)]], ), ( "SELECT * FROM parquet_ddf WHERE (b > 5 AND b < 10) OR a = 1", lambda x: x[((x["b"] > 5) & (x["b"] < 10)) | (x["a"] == 1)], [[("b", ">", 5), ("b", "<", 10)], [("a", "==", 1)]], ), pytest.param( "SELECT * FROM parquet_ddf WHERE b IN (1, 6)", lambda x: x[(x["b"] == 1) | (x["b"] == 6)], [[("b", "==", 1)], [("b", "==", 6)]], ), pytest.param( "SELECT * FROM parquet_ddf WHERE b IN (1, 3, 5, 6)", lambda x: x[x["b"].isin([1, 3, 5, 6])], [[("b", "in", (1, 3, 5, 6))]], ), pytest.param( "SELECT * FROM parquet_ddf WHERE c IN ('A', 'B', 'C', 'D')", lambda x: x[x["c"].isin(["A", "B", "C", "D"])], 
[[("c", "in", ("A", "B", "C", "D"))]], ), pytest.param( "SELECT * FROM parquet_ddf WHERE b NOT IN (1, 6)", lambda x: x[(x["b"] != 1) & (x["b"] != 6)], [[("b", "!=", 1), ("b", "!=", 6)]], ), pytest.param( "SELECT * FROM parquet_ddf WHERE b NOT IN (1, 3, 5, 6)", lambda x: x[~x["b"].isin([1, 3, 5, 6])], [[("b", "not in", (1, 3, 5, 6))]], ), ( "SELECT a FROM parquet_ddf WHERE (b > 5 AND b < 10) OR a = 1", lambda x: x[((x["b"] > 5) & (x["b"] < 10)) | (x["a"] == 1)][["a"]], [[("b", ">", 5), ("b", "<", 10)], [("a", "==", 1)]], ), ( # Original filters NOT in disjunctive normal form "SELECT a FROM parquet_ddf WHERE (parquet_ddf.b > 3 AND parquet_ddf.b < 10 OR parquet_ddf.a = 1) AND (parquet_ddf.c = 'A')", lambda x: x[ ((x["b"] > 3) & (x["b"] < 10) | (x["a"] == 1)) & (x["c"] == "A") ][["a"]], [ [("c", "==", "A"), ("b", ">", 3), ("b", "<", 10)], [("a", "==", 1), ("c", "==", "A")], ], ), ( # The predicate-pushdown optimization will be skipped here, # because datetime accessors are not supported. However, # the query should still succeed. "SELECT * FROM parquet_ddf WHERE year(d) < 2015", lambda x: x[x["d"].dt.year < 2015], None, ), ], ) @skipif_dask_expr_enabled() def test_predicate_pushdown(c, parquet_ddf, query, df_func, filters): # Check for predicate pushdown. # We can use the `hlg_layer` utility to make sure the # `filters` field has been populated in `creation_info` return_df = c.sql(query) expect_filters = filters got_filters = hlg_layer(return_df.dask, "read-parquet").creation_info["kwargs"][ "filters" ] if expect_filters: got_filters = frozenset(frozenset(v) for v in got_filters) expect_filters = frozenset(frozenset(v) for v in filters) assert got_filters == expect_filters # Check computed result is correct df = parquet_ddf expected_df = df_func(df) # divisions aren't equal for older dask versions assert_eq( return_df, expected_df, check_index=False, check_divisions=DASK_GT_2022_4_2 ) def test_filtered_csv(tmpdir, c): # Predicate pushdown is NOT supported for CSV data. 
# This test just checks that the "attempted" # predicate-pushdown logic does not lead to # any unexpected errors # Write simple csv dataset df = pd.DataFrame( { "a": [1, 2, 3] * 5, "b": range(15), "c": ["A"] * 15, }, ) dd.from_pandas(df, npartitions=3).to_csv(tmpdir + "/*.csv", index=False) # Read back with dask and apply WHERE query csv_ddf = dd.read_csv(tmpdir + "/*.csv") try: c.create_table("my_csv_table", csv_ddf) return_df = c.sql("SELECT * FROM my_csv_table WHERE b < 10") finally: c.drop_table("my_csv_table") # Check computed result is correct df = csv_ddf expected_df = df[df["b"] < 10] assert_eq(return_df, expected_df) @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_filter_decimal(c, gpu): df = pd.DataFrame( { "a": [304.5, 35.305, 9.043, 102.424, 53.34], "b": [2.2, 82.4, 42, 76.9, 54.4], "c": [1, 2, 2, 5, 9], } ) c.create_table("df", df, gpu=gpu) result_df = c.sql( """ SELECT c FROM df WHERE CAST(a AS DECIMAL) < CAST(b AS DECIMAL) """ ) expected_df = df.loc[df.a < df.b][["c"]] assert_eq(result_df, expected_df) result_df = c.sql( """ SELECT CAST(b AS DECIMAL) as b FROM df WHERE CAST(a AS DECIMAL) < DECIMAL '100.2' """ ) # decimal precision doesn't match up with pandas floats if gpu: result_df["b"] = result_df["b"].astype("float64") expected_df = df.loc[df.a < 100.2][["b"]] assert_eq(result_df, expected_df, check_index=False) c.drop_table("df") @skipif_dask_expr_enabled() def test_predicate_pushdown_isna(tmpdir): from dask_sql.context import Context c = Context() path = str(tmpdir) dd.from_pandas( pd.DataFrame( { "a": [1, 2, None] * 5, "b": range(15), "index": range(15), } ), npartitions=3, ).to_parquet(path + "/df1") df1 = dd.read_parquet(path + "/df1", index="index") c.create_table("df1", df1) dd.from_pandas( pd.DataFrame( { "a": [None, 2, 3] * 5, "b": range(15), "index": range(15), }, ), npartitions=3, ).to_parquet(path + "/df2") df2 = dd.read_parquet(path + "/df2", index="index") c.create_table("df2", df2) 
return_df = c.sql("SELECT df1.a FROM df1, df2 WHERE df1.a = df2.a") # Check for predicate pushdown filters = [[("a", "is not", None)]] got_filters = hlg_layer(return_df.dask, "read-parquet").creation_info["kwargs"][ "filters" ] got_filters = frozenset(frozenset(v) for v in got_filters) expect_filters = frozenset(frozenset(v) for v in filters) assert got_filters == expect_filters assert all(return_df.compute() == 2) assert len(return_df) == 25 ================================================ FILE: tests/integration/test_fugue.py ================================================ import dask.dataframe as dd import pandas as pd import pytest from dask_sql import Context from tests.utils import assert_eq fugue_sql = pytest.importorskip("fugue_sql") from dask_sql.integrations.fugue import fsql_dask # noqa: E402 def test_fugue_workflow(client): dag = fugue_sql.FugueSQLWorkflow() df = dag.df([[0, "hello"], [1, "world"]], "a:int64,b:str") dag("SELECT * FROM df WHERE a > 0 YIELD DATAFRAME AS result") result = dag.run("dask") return_df = result["result"].as_pandas() assert_eq(return_df, pd.DataFrame({"a": [1], "b": ["world"]})) result = dag.run(client) return_df = result["result"].as_pandas() assert_eq(return_df, pd.DataFrame({"a": [1], "b": ["world"]})) def test_fugue_fsql(client): pdf = pd.DataFrame([[0, "hello"], [1, "world"]], columns=["a", "b"]) dag = fugue_sql.fsql( "SELECT * FROM df WHERE a > 0 YIELD DATAFRAME AS result", df=pdf, ) result = dag.run("dask") return_df = result["result"].as_pandas() assert_eq(return_df, pd.DataFrame({"a": [1], "b": ["world"]})) result = dag.run(client) return_df = result["result"].as_pandas() assert_eq(return_df, pd.DataFrame({"a": [1], "b": ["world"]})) @pytest.mark.flaky(reruns=4, condition="sys.version_info < (3, 10)") def test_dask_fsql(client): def assert_fsql(df: pd.DataFrame) -> None: assert_eq(df, pd.DataFrame({"a": [1]})) # the simplest case: the SQL does not use any input and does not generate output fsql_dask( """ CREATE 
[[0],[1]] SCHEMA a:long SELECT * WHERE a>0 OUTPUT USING assert_fsql """ ) # it can directly use the dataframes inside dask-sql Context c = Context() c.create_table( "df", dd.from_pandas(pd.DataFrame([[0], [1]], columns=["a"]), npartitions=2) ) fsql_dask( """ SELECT * FROM df WHERE a>0 OUTPUT USING assert_fsql """, c, ) # for dataframes with name, they can register back to the Context (register=True) # the return of fsql is the dict of all dask dataframes with explicit names result = fsql_dask( """ x=SELECT * FROM df WHERE a>0 OUTPUT USING assert_fsql """, c, register=True, ) assert isinstance(result["x"], dd.DataFrame) assert "x" in c.schema[c.schema_name].tables # integration test with fugue transformer extension c = Context() c.create_table( "df1", dd.from_pandas( pd.DataFrame([[0, 1], [1, 2]], columns=["a", "b"]), npartitions=2 ), ) c.create_table( "df2", dd.from_pandas( pd.DataFrame([[1, 2], [3, 4], [-4, 5]], columns=["a", "b"]), npartitions=2 ), ) # schema: * def cumsum(df: pd.DataFrame) -> pd.DataFrame: return df.cumsum() fsql_dask( """ data = SELECT * FROM df1 WHERE a>0 UNION ALL SELECT * FROM df2 WHERE a>0 PERSIST result1 = TRANSFORM data PREPARTITION BY a PRESORT b USING cumsum result2 = TRANSFORM data PREPARTITION BY b PRESORT a USING cumsum PRINT result1, result2 """, c, register=True, ) assert "result1" in c.schema[c.schema_name].tables assert "result2" in c.schema[c.schema_name].tables ================================================ FILE: tests/integration/test_function.py ================================================ import itertools import operator import sys import dask.dataframe as dd import numpy as np import pytest from dask_sql.utils import ParsingException from tests.utils import assert_eq def test_custom_function(c, df): def f(x): return x**2 c.register_function(f, "f", [("x", np.float64)], np.float64) return_df = c.sql("SELECT F(a) AS a FROM df") assert_eq(return_df, df[["a"]] ** 2) def test_custom_function_row(c, df): def f(row): return 
row["x"] ** 2 c.register_function(f, "f", [("x", np.float64)], np.float64, row_udf=True) return_df = c.sql("SELECT F(a) AS a FROM df") assert_eq(return_df, df[["a"]] ** 2) @pytest.mark.parametrize("colnames", list(itertools.combinations(["a", "b", "c"], 2))) def test_custom_function_any_colnames(colnames, df_wide, c): # a third column is needed def f(row): return row["x"] + row["y"] colname_x, colname_y = colnames c.register_function( f, "f", [("x", np.int64), ("y", np.int64)], np.int64, row_udf=True ) return_df = c.sql(f"SELECT F({colname_x},{colname_y}) FROM df_wide") expect = df_wide[colname_x] + df_wide[colname_y] got = return_df.iloc[:, 0] assert_eq(expect, got, check_names=False) @pytest.mark.parametrize( "retty", [np.float64, np.float32, np.int64, np.int32, np.int16, np.int8, np.bool_], ) def test_custom_function_row_return_types(c, df, retty): def f(row): return row["x"] ** 2 c.register_function(f, "f", [("x", np.float64)], retty, row_udf=True) return_df = c.sql("SELECT F(a) AS a FROM df") assert_eq(return_df, (df[["a"]] ** 2).astype(retty)) # Test row UDFs with one arg @pytest.mark.parametrize("k", [1, 1.5, True]) @pytest.mark.parametrize( "op", [operator.add, operator.sub, operator.mul, operator.truediv] ) @pytest.mark.parametrize("retty", [np.int64, np.float64, np.bool_]) def test_custom_function_row_args(c, df, k, op, retty): const_type = np.dtype(type(k)).type if sys.platform == "win32" and const_type == np.int32: const_type = np.int64 def f(row, k): return op(row["a"], k) c.register_function( f, "f", [("a", np.float64), ("k", const_type)], retty, row_udf=True ) return_df = c.sql(f"SELECT F(a, {k}) as a from df") expected_df = op(df[["a"]], k).astype(retty) assert_eq(return_df, expected_df) # Test row UDFs with two args @pytest.mark.parametrize("k2", [1, 1.5, True]) @pytest.mark.parametrize("k1", [1, 1.5, True]) @pytest.mark.parametrize( "op", [operator.add, operator.sub, operator.mul, operator.truediv] ) @pytest.mark.parametrize("retty", [np.int64, 
np.float64, np.bool_]) def test_custom_function_row_two_args(c, df, k1, k2, op, retty): const_type_k1 = np.dtype(type(k1)).type const_type_k2 = np.dtype(type(k2)).type if sys.platform == "win32": if const_type_k1 == np.int32: const_type_k1 = np.int64 if const_type_k2 == np.int32: const_type_k2 = np.int64 def f(row, k1, k2): x = op(row["a"], k1) y = op(x, k2) return y c.register_function( f, "f", [("a", np.float64), ("k1", const_type_k1), ("k2", const_type_k2)], retty, row_udf=True, ) return_df = c.sql(f"SELECT F(a, {k1}, {k2}) as a from df") expected_df = op(op(df[["a"]], k1), k2).astype(retty) assert_eq(return_df, expected_df) def test_multiple_definitions(c, df_simple): def f(x): return x**2 c.register_function(f, "f", [("x", np.float64)], np.float64) c.register_function(f, "f", [("x", np.int64)], np.int64) return_df = c.sql( """ SELECT F(a) AS a, f(b) AS b FROM df_simple """ ) expected_df = df_simple[["a", "b"]] ** 2 assert_eq(return_df, expected_df) def f(x): return x**3 c.register_function(f, "f", [("x", np.float64)], np.float64, replace=True) c.register_function(f, "f", [("x", np.int64)], np.int64) return_df = c.sql( """ SELECT F(a) AS a, f(b) AS b FROM df_simple """ ) expected_df = df_simple[["a", "b"]] ** 3 assert_eq(return_df, expected_df) def test_aggregate_function(c): fagg = dd.Aggregation("f", lambda x: x.sum(), lambda x: x.sum()) c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64) return_df = c.sql( """ SELECT FAGG(b) AS test, SUM(b) AS "S" FROM df """ ) assert_eq(return_df["test"], return_df["S"], check_names=False) def test_reregistration(c): def f(x): return x**2 # The same is fine c.register_function(f, "f", [("x", np.float64)], np.float64) c.register_function(f, "f", [("x", np.int64)], np.int64) def f(x): return x**3 # A different not with pytest.raises(ValueError): c.register_function(f, "f", [("x", np.float64)], np.float64) # only if we replace it c.register_function(f, "f", [("x", np.float64)], np.float64, replace=True) fagg 
= dd.Aggregation("f", lambda x: x.sum(), lambda x: x.sum()) c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64) c.register_aggregation(fagg, "fagg", [("x", np.int64)], np.int64) fagg = dd.Aggregation("f", lambda x: x.mean(), lambda x: x.mean()) with pytest.raises(ValueError): c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64) c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64, replace=True) @pytest.mark.parametrize("dtype", [np.timedelta64, None, "a string"]) def test_unsupported_dtype(c, dtype): def f(x): return x**2 # test that an invalid return type raises with pytest.raises(NotImplementedError): c.register_function(f, "f", [("x", np.int64)], dtype) # test that an invalid param type raises with pytest.raises(NotImplementedError): c.register_function(f, "f", [("x", dtype)], np.int64) # TODO: explore implicitly casting inputs to the expected types consistently def test_wrong_input_type(c): def f(a): return a c.register_function(f, "f", [("a", np.int64)], np.int64) with pytest.raises(ParsingException): c.sql("SELECT F(CAST(a AS INT)) AS a FROM df") ================================================ FILE: tests/integration/test_groupby.py ================================================ import dask.dataframe as dd import numpy as np import pandas as pd import pytest from dask.datasets import timeseries from tests.utils import assert_eq @pytest.fixture() def timeseries_df(c): pdf = timeseries(freq="1d").compute().reset_index(drop=True) # input nans in pandas dataframe col1_index = np.random.randint(0, 30, size=int(pdf.shape[0] * 0.2)) col2_index = np.random.randint(0, 30, size=int(pdf.shape[0] * 0.3)) pdf.loc[col1_index, "x"] = np.nan pdf.loc[col2_index, "y"] = np.nan c.create_table("timeseries", pdf, persist=True) return None def test_group_by(c): return_df = c.sql( """ SELECT user_id, SUM(b) AS "S" FROM user_table_1 GROUP BY user_id """ ) expected_df = pd.DataFrame({"user_id": [1, 2, 3], "S": [3, 4, 3]}) 
assert_eq(return_df.sort_values("user_id").reset_index(drop=True), expected_df) @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_group_by_multi(c, gpu): df = pd.DataFrame({"a": [1, 2, 3], "b": [1, 1, 2]}) c.create_table("df", df, gpu=gpu) result_df = c.sql( """ SELECT SUM(a) AS s, AVG(a) AS av, COUNT(a) AS c FROM df GROUP BY b """ ) expected_df = pd.DataFrame( { "s": df.groupby("b").sum()["a"], "av": df.groupby("b").mean()["a"], "c": df.groupby("b").count()["a"], } ) assert_eq(result_df, expected_df, check_index=False) c.drop_table("df") def test_group_by_all(c, df): result_df = c.sql( """ SELECT SUM(b) AS "S", SUM(2) AS "X" FROM user_table_1 """ ) expected_df = pd.DataFrame({"S": [10], "X": [8]}) assert_eq(result_df, expected_df) result_df = c.sql( """ SELECT SUM(a) AS sum_a, AVG(a) AS avg_a, SUM(b) AS sum_b, AVG(b) AS avg_b, SUM(a)+AVG(b) AS mix_1, SUM(a+b) AS mix_2, AVG(a+b) AS mix_3 FROM df """ ) expected_df = pd.DataFrame( { "sum_a": [df.a.sum()], "avg_a": [df.a.mean()], "sum_b": [df.b.sum()], "avg_b": [df.b.mean()], "mix_1": [df.a.sum() + df.b.mean()], "mix_2": [(df.a + df.b).sum()], "mix_3": [(df.a + df.b).mean()], } ) assert_eq(result_df, expected_df) def test_group_by_filtered(c): return_df = c.sql( """ SELECT SUM(b) FILTER (WHERE user_id = 2) AS "S1", SUM(b) "S2" FROM user_table_1 """ ) expected_df = pd.DataFrame({"S1": [4], "S2": [10]}, dtype="int64") assert_eq(return_df, expected_df) return_df = c.sql( """ SELECT user_id, SUM(b) FILTER (WHERE user_id = 2) AS "S1", SUM(b) "S2" FROM user_table_1 GROUP BY user_id """ ) expected_df = pd.DataFrame( { "user_id": [1, 2, 3], "S1": [np.NaN, 4.0, np.NaN], "S2": [3, 4, 3], }, ) assert_eq(return_df, expected_df, check_index=False) return_df = c.sql( """ SELECT SUM(b) FILTER (WHERE user_id = 2) AS "S1" FROM user_table_1 """ ) expected_df = pd.DataFrame({"S1": [4]}) assert_eq(return_df, expected_df) @pytest.mark.xfail(reason="WIP DataFusion") def test_group_by_case(c): 
return_df = c.sql( """ SELECT user_id + 1 AS "A", SUM(CASE WHEN b = 3 THEN 1 END) AS "S" FROM user_table_1 GROUP BY user_id + 1 """ ) expected_df = pd.DataFrame({"A": [2, 3, 4], "S": [1, 1, 1]}) # Do not check dtypes, as pandas versions are inconsistent here assert_eq( return_df.sort_values("A").reset_index(drop=True), expected_df, check_dtype=False, ) def test_group_by_nan(c, user_table_nan): return_df = c.sql( """ SELECT c FROM user_table_nan GROUP BY c """ ) expected_df = user_table_nan.drop_duplicates(subset=["c"]) # we return nullable int dtype instead of float assert_eq(return_df, expected_df, check_dtype=False) return_df = c.sql( """ SELECT c FROM user_table_inf GROUP BY c """ ) expected_df = pd.DataFrame({"c": [3, 1, float("inf")]}) expected_df["c"] = expected_df["c"].astype("float64") assert_eq( return_df.sort_values("c").reset_index(drop=True), expected_df.sort_values("c").reset_index(drop=True), ) def test_aggregations(c): return_df = c.sql( """ SELECT user_id, EVERY(b = 3) AS e, BIT_AND(b) AS b, BIT_OR(b) AS bb, MIN(b) AS m, SINGLE_VALUE(b) AS s, AVG(b) AS a FROM user_table_1 GROUP BY user_id """ ) expected_df = pd.DataFrame( { "user_id": [1, 2, 3], "e": [True, False, True], "b": [3, 1, 3], "bb": [3, 3, 3], "m": [3, 1, 3], "s": [3, 3, 3], "a": [3, 2, 3], } ) expected_df["a"] = expected_df["a"].astype("float64") assert_eq(return_df.sort_values("user_id").reset_index(drop=True), expected_df) return_df = c.sql( """ SELECT user_id, EVERY(c = 3) AS e, BIT_AND(c) AS b, BIT_OR(c) AS bb, MIN(c) AS m, SINGLE_VALUE(c) AS s, AVG(c) AS a FROM user_table_2 GROUP BY user_id """ ) expected_df = pd.DataFrame( { "user_id": [1, 2, 4], "e": [False, True, False], "b": [0, 3, 4], "bb": [3, 3, 4], "m": [1, 3, 4], "s": [1, 3, 4], "a": [1.5, 3, 4], } ) assert_eq(return_df.sort_values("user_id").reset_index(drop=True), expected_df) return_df = c.sql( """ SELECT MAX(a) AS "max", MIN(a) AS "min" FROM string_table """ ) expected_df = pd.DataFrame({"max": ["a normal string"], 
"min": ["%_%"]}) assert_eq(return_df.reset_index(drop=True), expected_df) @pytest.mark.parametrize( "gpu", [ False, pytest.param( True, marks=( pytest.mark.gpu, pytest.mark.xfail( reason="stddev_pop is failing on GPU, see https://github.com/dask-contrib/dask-sql/issues/681" ), ), ), ], ) def test_stddev(c, gpu): df = pd.DataFrame( { "a": [1, 1, 2, 1, 2], "b": [4, 6, 3, 8, 5], } ) c.create_table("df", df, gpu=gpu) return_df = c.sql( """ SELECT STDDEV(b) AS s FROM df GROUP BY df.a """ ) expected_df = pd.DataFrame({"s": df.groupby("a").std()["b"]}) assert_eq(return_df, expected_df, check_index=False) return_df = c.sql( """ SELECT STDDEV_SAMP(b) AS ss FROM df """ ) expected_df = pd.DataFrame({"ss": [df.std()["b"]]}) assert_eq(return_df, expected_df.reset_index(drop=True)) return_df = c.sql( """ SELECT STDDEV_POP(b) AS sp FROM df GROUP BY df.a """ ) expected_df = pd.DataFrame({"sp": df.groupby("a").std(ddof=0)["b"]}) assert_eq(return_df, expected_df.reset_index(drop=True)) return_df = c.sql( """ SELECT STDDEV(a) as s, STDDEV_SAMP(a) ss, STDDEV_POP(b) sp FROM df """ ) expected_df = pd.DataFrame( { "s": [df.std()["a"]], "ss": [df.std()["a"]], "sp": [df.std(ddof=0)["b"]], } ) assert_eq(return_df, expected_df.reset_index(drop=True)) c.drop_table("df") @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_regr_aggregation(c, timeseries_df, gpu): # test regr_count regr_count = c.sql( """ SELECT name, COUNT(x) FILTER (WHERE y IS NOT NULL) AS expected, REGR_COUNT(y, x) AS calculated FROM timeseries GROUP BY name """ ).fillna(0) assert_eq( regr_count["expected"], regr_count["calculated"], check_dtype=False, check_names=False, ) # test regr_syy regr_syy = c.sql( """ SELECT name, (REGR_COUNT(y, x) * VAR_POP(y)) AS expected, REGR_SYY(y, x) AS calculated FROM timeseries WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY name """ ).fillna(0) assert_eq( regr_syy["expected"], regr_syy["calculated"], check_dtype=False, check_names=False, ) # test 
regr_sxx regr_sxx = c.sql( """ SELECT name, (REGR_COUNT(y, x) * VAR_POP(x)) AS expected, REGR_SXX(y,x) AS calculated FROM timeseries WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY name """ ).fillna(0) assert_eq( regr_sxx["expected"], regr_sxx["calculated"], check_dtype=False, check_names=False, ) @pytest.mark.xfail( reason="WIP DataFusion - https://github.com/dask-contrib/dask-sql/issues/753" ) def test_covar_aggregation(c, timeseries_df): # test covar_pop covar_pop = c.sql( """ WITH temp_agg AS ( SELECT name, AVG(y) FILTER (WHERE x IS NOT NULL) as avg_y, AVG(x) FILTER (WHERE x IS NOT NULL) as avg_x FROM timeseries GROUP BY name ) SELECT ts.name, SUM((y - avg_y) * (x - avg_x)) / REGR_COUNT(y, x) AS expected, COVAR_POP(y,x) AS calculated FROM timeseries AS ts JOIN temp_agg AS ta ON ts.name = ta.name GROUP BY ts.name """ ).fillna(0) assert_eq( covar_pop["expected"], covar_pop["calculated"], check_dtype=False, check_names=False, ) # test covar_samp covar_samp = c.sql( """ WITH temp_agg AS ( SELECT name, AVG(y) FILTER (WHERE x IS NOT NULL) as avg_y, AVG(x) FILTER (WHERE x IS NOT NULL) as avg_x FROM timeseries GROUP BY name ) SELECT ts.name, SUM((y - avg_y) * (x - avg_x)) / (REGR_COUNT(y, x) - 1) as expected, COVAR_SAMP(y,x) AS calculated FROM timeseries AS ts JOIN temp_agg AS ta ON ts.name = ta.name GROUP BY ts.name """ ).fillna(0) assert_eq( covar_samp["expected"], covar_samp["calculated"], check_dtype=False, check_names=False, ) @pytest.mark.parametrize( "input_table", [ "user_table_1", pytest.param("gpu_user_table_1", marks=pytest.mark.gpu), ], ) @pytest.mark.parametrize("split_out", [1, 2, 4]) def test_groupby_split_out(c, input_table, split_out, request): user_table = request.getfixturevalue(input_table) return_df = c.sql( f""" SELECT user_id, SUM(b) AS "S" FROM {input_table} GROUP BY user_id """, config_options={"sql.aggregate.split_out": split_out} if split_out else {}, ) expected_df = ( user_table.groupby(by="user_id") .agg({"b": "sum"}) 
.reset_index(drop=False) .rename(columns={"b": "S"}) .sort_values("user_id") ) assert return_df.npartitions == split_out if split_out else 1 assert_eq(return_df.sort_values("user_id"), expected_df, check_index=False) return_df = c.sql( f""" SELECT DISTINCT(user_id) FROM {input_table} """, config_options={"sql.aggregate.split_out": split_out}, ) expected_df = user_table[["user_id"]].drop_duplicates() assert return_df.npartitions == split_out if split_out else 1 assert_eq(return_df.sort_values("user_id"), expected_df, check_index=False) @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_groupby_split_every(c, gpu): input_ddf = dd.from_pandas( pd.DataFrame({"user_id": [1, 2, 3, 4] * 16, "b": [5, 6, 7, 8] * 16}), npartitions=16, ) # Need an input with multiple partitions to demonstrate split_every c.create_table("split_every_input", input_ddf, gpu=gpu) query_string = """ SELECT user_id, SUM(b) AS "S" FROM split_every_input GROUP BY user_id """ split_every_2_df = c.sql( query_string, config_options={"sql.aggregate.split_every": 2}, ) split_every_3_df = c.sql( query_string, config_options={"sql.aggregate.split_every": 3}, ) split_every_4_df = c.sql( query_string, config_options={"sql.aggregate.split_every": 4}, ) expected_df = ( input_ddf.groupby(by="user_id") .agg({"b": "sum"}) .reset_index(drop=False) .rename(columns={"b": "S"}) .sort_values("user_id") ) assert ( len(split_every_2_df.dask.keys()) >= len(split_every_3_df.dask.keys()) >= len(split_every_4_df.dask.keys()) ) assert_eq(split_every_2_df, expected_df, check_index=False) assert_eq(split_every_3_df, expected_df, check_index=False) assert_eq(split_every_4_df, expected_df, check_index=False) query_string = """ SELECT DISTINCT(user_id) FROM split_every_input """ split_every_2_df = c.sql( query_string, config_options={"sql.aggregate.split_every": 2}, ) split_every_3_df = c.sql( query_string, config_options={"sql.aggregate.split_every": 3}, ) split_every_4_df = c.sql( 
query_string, config_options={"sql.aggregate.split_every": 4}, ) expected_df = input_ddf[["user_id"]].drop_duplicates() assert ( len(split_every_2_df.dask.keys()) >= len(split_every_3_df.dask.keys()) >= len(split_every_4_df.dask.keys()) ) assert_eq(split_every_2_df, expected_df, check_index=False) assert_eq(split_every_3_df, expected_df, check_index=False) assert_eq(split_every_4_df, expected_df, check_index=False) c.drop_table("split_every_input") @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_agg_decimal(c, gpu): df = pd.DataFrame( { "a": [1.23, 12.65, 134.64, -34.3, 945.19], "b": [1, 1, 2, 2, 3], } ) c.create_table("df", df, gpu=gpu) result_df = c.sql( """ SELECT SUM(CAST(a AS DECIMAL)) as s, COUNT(CAST(a AS DECIMAL)) as c, SUM(CAST(a+a AS DECIMAL)) as s2 FROM df GROUP BY b """ ) # decimal precision doesn't match up with pandas floats if gpu: result_df["s"] = result_df["s"].astype("float64") result_df["s2"] = result_df["s2"].astype("float64") expected_df = pd.DataFrame( { "s": df.groupby("b").sum()["a"], "c": df.groupby("b").count()["a"], "s2": df.groupby("b").sum()["a"] + df.groupby("b").sum()["a"], } ) # dtype of count aggregation is float on gpu assert_eq(result_df, expected_df, check_index=False, check_dtype=(not gpu)) result_df = c.sql( """ SELECT MIN(CAST(a AS DECIMAL)) as min, MAX(CAST(a AS DECIMAL)) as max FROM df """ ) # decimal precision doesn't match up with pandas floats if gpu: result_df["min"] = result_df["min"].astype("float64") result_df["max"] = result_df["max"].astype("float64") expected_df = pd.DataFrame( { "min": [df.a.min()], "max": [df.a.max()], } ) assert_eq(result_df, expected_df) c.drop_table("df") ================================================ FILE: tests/integration/test_hive.py ================================================ import shutil import sys import tempfile import time import pandas as pd import pytest from dask_sql.context import Context from tests.utils import assert_eq 
pytestmark = pytest.mark.xfail( condition=sys.platform in ("win32", "darwin"), reason="hive testing not supported on Windows/macOS", ) docker = pytest.importorskip("docker") sqlalchemy = pytest.importorskip("sqlalchemy") pytest.importorskip("pyhive") DEFAULT_CONFIG = { "HIVE_SITE_CONF_javax_jdo_option_ConnectionURL": "jdbc:postgresql://hive-metastore-postgresql/metastore", "HIVE_SITE_CONF_javax_jdo_option_ConnectionDriverName": "org.postgresql.Driver", "HIVE_SITE_CONF_javax_jdo_option_ConnectionUserName": "hive", "HIVE_SITE_CONF_javax_jdo_option_ConnectionPassword": "hive", "HIVE_SITE_CONF_datanucleus_autoCreateSchema": "false", "HIVE_SITE_CONF_hive_metastore_uris": "thrift://hive-metastore:9083", "HDFS_CONF_dfs_namenode_datanode_registration_ip___hostname___check": "false", "CORE_CONF_fs_defaultFS": "file:///database", "CORE_CONF_hadoop_http_staticuser_user": "root", "CORE_CONF_hadoop_proxyuser_hue_hosts": "*", "CORE_CONF_hadoop_proxyuser_hue_groups": "*", "HIVE_SITE_CONF_fs_default_name": "file:///database", "CORE_CONF_fs_defaultFS": "file:///database", "HIVE_SIZE_CONF_hive_metastore_warehouse_dir": "file:///database", } @pytest.fixture(scope="session") def hive_cursor(): """ Getting a hive setup up and running is a bit more complicated. We need three running docker containers: * a postgres database to store the metadata * the metadata server itself * and a server to answer SQL queries They are all started one after the other to check, if they are up and running. We "fake" a network filesystem (instead of using hdfs), by mounting a temporary folder from the host to the docker container, which can be accessed both by hive and the dask-sql client. We just need to make sure, to remove all containers, the network and the temporary folders correctly again. The ideas for the docker setup are taken from the docker-compose hive setup described by bde2020. 
""" client = docker.from_env() network = None hive_server = None hive_metastore = None hive_postgres = None tmpdir = tempfile.mkdtemp() tmpdir_parted = tempfile.mkdtemp() tmpdir_multiparted = tempfile.mkdtemp() try: network = client.networks.create("dask-sql-hive", driver="bridge") hive_server = client.containers.create( "bde2020/hive:2.3.2-postgresql-metastore", hostname="hive-server", name="hive-server", network="dask-sql-hive", volumes=[ f"{tmpdir}:{tmpdir}", f"{tmpdir_parted}:{tmpdir_parted}", f"{tmpdir_multiparted}:{tmpdir_multiparted}", ], environment={ "HIVE_CORE_CONF_javax_jdo_option_ConnectionURL": "jdbc:postgresql://hive-metastore-postgresql/metastore", **DEFAULT_CONFIG, }, ) hive_metastore = client.containers.create( "bde2020/hive:2.3.2-postgresql-metastore", hostname="hive-metastore", name="hive-metastore", network="dask-sql-hive", environment=DEFAULT_CONFIG, command="/opt/hive/bin/hive --service metastore", ) hive_postgres = client.containers.create( "bde2020/hive-metastore-postgresql:2.3.0", hostname="hive-metastore-postgresql", name="hive-metastore-postgresql", network="dask-sql-hive", ) # Wait for it to start hive_postgres.start() hive_postgres.exec_run(["bash"]) for l in hive_postgres.logs(stream=True): if b"ready for start up." in l: break hive_metastore.start() hive_metastore.exec_run(["bash"]) for l in hive_metastore.logs(stream=True): if b"Starting hive metastore" in l: break hive_server.start() hive_server.exec_run(["bash"]) for l in hive_server.logs(stream=True): if b"Starting HiveServer2" in l: break # The server needs some time to start. # It is easier to check for the first access # on the metastore than to wait some # arbitrary time. 
for l in hive_metastore.logs(stream=True): if b"get_multi_table" in l: break time.sleep(2) hive_server.reload() address = hive_server.attrs["NetworkSettings"]["Networks"]["dask-sql-hive"][ "IPAddress" ] port = 10000 cursor = sqlalchemy.create_engine(f"hive://{address}:{port}").connect() # Create a non-partitioned column cursor.execute( sqlalchemy.text( f"CREATE TABLE df (i INTEGER, j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir}'" ) ) cursor.execute(sqlalchemy.text("INSERT INTO df (i, j) VALUES (1, 2)")) cursor.execute(sqlalchemy.text("INSERT INTO df (i, j) VALUES (2, 4)")) cursor.execute( sqlalchemy.text( f"CREATE TABLE df_part (i INTEGER) PARTITIONED BY (j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_parted}'" ) ) cursor.execute( sqlalchemy.text("INSERT INTO df_part PARTITION (j=2) (i) VALUES (1)") ) cursor.execute( sqlalchemy.text("INSERT INTO df_part PARTITION (j=4) (i) VALUES (2)") ) cursor.execute( sqlalchemy.text( f""" CREATE TABLE df_parts (i INTEGER) PARTITIONED BY (j INTEGER, k STRING) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_multiparted}' """ ) ) cursor.execute( sqlalchemy.text( "INSERT INTO df_parts PARTITION (j=1, k='a') (i) VALUES (1)" ) ) cursor.execute( sqlalchemy.text( "INSERT INTO df_parts PARTITION (j=2, k='b') (i) VALUES (2)" ) ) # The data files are created as root user by default. 
Change that: hive_server.exec_run(["chmod", "a+rwx", "-R", tmpdir]) hive_server.exec_run(["chmod", "a+rwx", "-R", tmpdir_parted]) hive_server.exec_run(["chmod", "a+rwx", "-R", tmpdir_multiparted]) yield cursor except docker.errors.ImageNotFound: pytest.skip( "Hive testing requires 'bde2020/hive:2.3.2-postgresql-metastore' and " "'bde2020/hive-metastore-postgresql:2.3.0' docker images" ) finally: # Now clean up: remove the containers and the network and the folders for container in [hive_server, hive_metastore, hive_postgres]: if container is None: continue try: container.kill() except Exception: pass container.remove() if network is not None: network.remove() shutil.rmtree(tmpdir) shutil.rmtree(tmpdir_parted) def test_select(hive_cursor): c = Context() c.create_table("df", hive_cursor) result_df = c.sql("SELECT * FROM df") expected_df = pd.DataFrame({"i": [1, 2], "j": [2, 4]}).astype("int32") assert_eq(result_df, expected_df, check_index=False) def test_select_partitions(hive_cursor): c = Context() c.create_table("df_part", hive_cursor) result_df = c.sql("SELECT * FROM df_part") expected_df = pd.DataFrame({"i": [1, 2], "j": [2, 4]}).astype("int32") expected_df["j"] = expected_df["j"].astype("int64") assert_eq(result_df, expected_df, check_index=False) def test_select_multipartitions(hive_cursor): c = Context() c.create_table("df_parts", hive_cursor) result_df = c.sql("SELECT * FROM df_parts") expected_df = pd.DataFrame({"i": [1, 2], "j": [1, 2], "k": ["a", "b"]}) expected_df["i"] = expected_df["i"].astype("int32") expected_df["j"] = expected_df["j"].astype("int64") expected_df["k"] = expected_df["k"].astype("object") assert_eq(result_df, expected_df, check_index=False) ================================================ FILE: tests/integration/test_intake.py ================================================ import os import shutil import tempfile import pandas as pd import pytest from dask_sql.context import Context from tests.utils import assert_eq, 
skipif_dask_expr_enabled # intake doesn't yet have proper dask-expr support pytestmark = skipif_dask_expr_enabled( reason="Intake doesn't yet have proper dask-expr support" ) # skip the test if intake is not installed intake = pytest.importorskip("intake") @pytest.fixture() def intake_catalog_location(): tmpdir = tempfile.mkdtemp() df = pd.DataFrame({"a": [1], "b": [1.5]}) csv_location = os.path.join(tmpdir, "data.csv") df.to_csv(csv_location, index=False) yaml_location = os.path.join(tmpdir, "catalog.yaml") with open(yaml_location, "w") as f: f.write( """sources: intake_table: args: urlpath: "{{ CATALOG_DIR }}/data.csv" description: "Some Data" driver: intake.source.csv.CSVSource """ ) try: yield yaml_location finally: shutil.rmtree(tmpdir) def check_read_table(c): result_df = c.sql("SELECT * FROM df").reset_index(drop=True) expected_df = pd.DataFrame({"a": [1], "b": [1.5]}) assert_eq(result_df, expected_df) def test_intake_catalog(intake_catalog_location): catalog = intake.open_catalog(intake_catalog_location) c = Context() c.create_table("df", catalog, intake_table_name="intake_table") check_read_table(c) def test_intake_location(intake_catalog_location): c = Context() c.create_table( "df", intake_catalog_location, format="intake", intake_table_name="intake_table" ) check_read_table(c) def test_intake_sql(intake_catalog_location): c = Context() c.sql( f""" CREATE TABLE df WITH ( location = '{intake_catalog_location}', format = 'intake', intake_table_name = 'intake_table' ) """ ) check_read_table(c) ================================================ FILE: tests/integration/test_jdbc.py ================================================ from time import sleep import pandas as pd import pytest from dask_sql import Context from dask_sql.server.app import _init_app, app from dask_sql.server.presto_jdbc import create_meta_data from tests.integration.fixtures import DISTRIBUTED_TESTS # needed for the testclient pytest.importorskip("requests") schema = "a_schema" table = 
"a_table" @pytest.fixture(scope="module") def c(): c = Context() c.create_schema(schema) tables = pd.DataFrame(create_table_row(), index=[0]) tables = tables.astype({"AN_INT": "int64"}) c.create_table(table, tables, schema_name=schema) yield c c.drop_schema(schema) @pytest.fixture(scope="module") def app_client(c): c.sql("SELECT 1 + 1").compute() _init_app(app, c) # late import for the importskip from fastapi.testclient import TestClient yield TestClient(app) # avoid closing client it's session-wide if not DISTRIBUTED_TESTS: app.client.close() @pytest.mark.xfail(reason="WIP DataFusion") def test_jdbc_has_schema(app_client, c): create_meta_data(c) check_data(app_client) response = app_client.post( "/v1/statement", data="SELECT * from system.jdbc.schemas" ) assert response.status_code == 200 result = get_result_or_error(app_client, response) assert_result(result, 2, 3) assert result["columns"] == [ { "name": "TABLE_CATALOG", "type": "varchar", "typeSignature": {"rawType": "varchar", "arguments": []}, }, { "name": "TABLE_SCHEM", "type": "varchar", "typeSignature": {"rawType": "varchar", "arguments": []}, }, ] assert result["data"] == [ ["", "root"], ["", "a_schema"], ["", "system_jdbc"], ] def test_jdbc_has_table(app_client, c): create_meta_data(c) check_data(app_client) response = app_client.post("/v1/statement", data="SELECT * from system.jdbc.tables") assert response.status_code == 200 result = get_result_or_error(app_client, response) assert_result(result, 10, 4) assert result["data"] == [ ["", "a_schema", "a_table", "", "", "", "", "", "", ""], ["", "system_jdbc", "schemas", "", "", "", "", "", "", ""], ["", "system_jdbc", "tables", "", "", "", "", "", "", ""], ["", "system_jdbc", "columns", "", "", "", "", "", "", ""], ] @pytest.mark.xfail(reason="WIP DataFusion") def test_jdbc_has_columns(app_client, c): create_meta_data(c) check_data(app_client) response = app_client.post( "/v1/statement", data=f"SELECT * from system.jdbc.columns where TABLE_NAME = '{table}'", 
) assert response.status_code == 200 client_result = get_result_or_error(app_client, response) # ordering of rows isn't consistent between fastapi versions context_result = ( c.sql("SELECT * FROM system_jdbc.columns WHERE TABLE_NAME = 'a_table'") .compute() .values.tolist() ) assert_result(client_result, 24, 3) assert client_result["data"] == context_result def assert_result(result, col_len, data_len): assert "columns" in result assert "data" in result assert "error" not in result assert len(result["columns"]) == col_len assert len(result["data"]) == data_len def create_table_row(a_str: str = "any", an_int: int = 1, a_float: float = 1.1): return { "A_STR": a_str, "AN_INT": an_int, "A_FLOAT": a_float, } def check_data(app_client): response = app_client.post("/v1/statement", data=f"SELECT * from {schema}.{table}") assert response.status_code == 200 a_table = get_result_or_error(app_client, response) assert "columns" in a_table assert "data" in a_table assert "error" not in a_table def get_result_or_error(app_client, response): result = response.json() assert "nextUri" in result assert "error" not in result status_url = result["nextUri"] next_url = status_url counter = 0 while True: response = app_client.get(next_url) assert response.status_code == 200 result = response.json() if "nextUri" not in result: break next_url = result["nextUri"] counter += 1 assert counter <= 100 sleep(0.1) return result ================================================ FILE: tests/integration/test_join.py ================================================ from contextlib import nullcontext import dask.dataframe as dd import numpy as np import pandas as pd import pytest from dask.utils_test import hlg_layer from dask_sql import Context from dask_sql.datacontainer import Statistics from tests.utils import assert_eq, skipif_dask_expr_enabled def test_join(c): return_df = c.sql( """ SELECT lhs.user_id, lhs.b, rhs.c FROM user_table_1 AS lhs JOIN user_table_2 AS rhs ON lhs.user_id = rhs.user_id """ 
) expected_df = pd.DataFrame( {"user_id": [1, 1, 2, 2], "b": [3, 3, 1, 3], "c": [1, 2, 3, 3]} ) assert_eq(return_df, expected_df, check_index=False) def test_join_inner(c): return_df = c.sql( """ SELECT lhs.user_id, lhs.b, rhs.c FROM user_table_1 AS lhs INNER JOIN user_table_2 AS rhs ON lhs.user_id = rhs.user_id """ ) expected_df = pd.DataFrame( {"user_id": [1, 1, 2, 2], "b": [3, 3, 1, 3], "c": [1, 2, 3, 3]} ) assert_eq(return_df, expected_df, check_index=False) def test_join_outer(c): return_df = c.sql( """ SELECT lhs.user_id, lhs.b, rhs.c FROM user_table_1 AS lhs FULL JOIN user_table_2 AS rhs ON lhs.user_id = rhs.user_id """ ) expected_df = pd.DataFrame( { # That is strange. Unfortunately, it seems dask fills in the # missing rows with NaN, not with NA... "user_id": [1, 1, 2, 2, 3, np.NaN], "b": [3, 3, 1, 3, 3, np.NaN], "c": [1, 2, 3, 3, np.NaN, 4], } ) assert_eq(return_df, expected_df, check_index=False) def test_join_left(c): return_df = c.sql( """ SELECT lhs.user_id, lhs.b, rhs.c FROM user_table_1 AS lhs LEFT JOIN user_table_2 AS rhs ON lhs.user_id = rhs.user_id """ ) expected_df = pd.DataFrame( { # That is strange. Unfortunately, it seems dask fills in the # missing rows with NaN, not with NA... 
"user_id": [1, 1, 2, 2, 3], "b": [3, 3, 1, 3, 3], "c": [1, 2, 3, 3, np.NaN], } ) assert_eq(return_df, expected_df, check_index=False) @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_join_left_anti(c, gpu): df1 = pd.DataFrame({"id": [1, 1, 2, 4], "a": ["a", "b", "c", "d"]}) df2 = pd.DataFrame({"id": [2, 1, 2, 3], "b": ["c", "c", "a", "c"]}) c.create_table("df_1", df1, gpu=gpu) c.create_table("df_2", df2, gpu=gpu) return_df = c.sql( """ SELECT lhs.id, lhs.a FROM df_1 AS lhs LEFT ANTI JOIN df_2 AS rhs ON lhs.id = rhs.id """ ) expected_df = pd.DataFrame( { "id": [4], "a": ["d"], } ) assert_eq(return_df, expected_df, check_index=False) @pytest.mark.gpu def test_join_left_semi(c): df1 = pd.DataFrame({"id": [1, 1, 2, 4], "a": ["a", "b", "c", "d"]}) df2 = pd.DataFrame({"id": [2, 1, 2, 3], "b": ["c", "c", "a", "c"]}) c.create_table("df_1", df1, gpu=True) c.create_table("df_2", df2, gpu=True) return_df = c.sql( """ SELECT lhs.id, lhs.a FROM df_1 AS lhs LEFT SEMI JOIN df_2 AS rhs ON lhs.id = rhs.id """ ) expected_df = pd.DataFrame( { "id": [1, 1, 2], "a": ["a", "b", "c"], } ) assert_eq(return_df, expected_df, check_index=False) def test_join_right(c): return_df = c.sql( """ SELECT lhs.user_id, lhs.b, rhs.c FROM user_table_1 AS lhs RIGHT JOIN user_table_2 AS rhs ON lhs.user_id = rhs.user_id """ ) expected_df = pd.DataFrame( { # That is strange. Unfortunately, it seems dask fills in the # missing rows with NaN, not with NA... 
"user_id": [1, 1, 2, 2, np.NaN], "b": [3, 3, 1, 3, np.NaN], "c": [1, 2, 3, 3, 4], } ) assert_eq(return_df, expected_df, check_index=False) def test_join_cross(c, user_table_1, department_table): return_df = c.sql( """ SELECT user_id, b, department_name FROM user_table_1, department_table """ ) user_table_1["key"] = 1 department_table["key"] = 1 expected_df = dd.merge(user_table_1, department_table, on="key").drop(columns="key") assert_eq(return_df, expected_df, check_index=False) def test_join_complex(c): return_df = c.sql( """ SELECT lhs.a, rhs.b FROM df_simple AS lhs JOIN df_simple AS rhs ON lhs.a < rhs.b """ ) expected_df = pd.DataFrame( {"a": [1, 1, 1, 2, 2, 3], "b": [1.1, 2.2, 3.3, 2.2, 3.3, 3.3]} ) assert_eq(return_df, expected_df, check_index=False) return_df = c.sql( """ SELECT lhs.a, lhs.b, rhs.a, rhs.b FROM df_simple AS lhs JOIN df_simple AS rhs ON lhs.a < rhs.b AND lhs.b < rhs.a """ ) expected_df = pd.DataFrame( { "lhs.a": [1, 1, 2], "lhs.b": [1.1, 1.1, 2.2], "rhs.a": [2, 3, 3], "rhs.b": [2.2, 3.3, 3.3], } ) assert_eq(return_df, expected_df, check_index=False) return_df = c.sql( """ SELECT lhs.user_id, lhs.b, rhs.user_id, rhs.c FROM user_table_1 AS lhs JOIN user_table_2 AS rhs ON rhs.user_id = lhs.user_id AND rhs.c - lhs.b >= 0 """ ) expected_df = pd.DataFrame( {"lhs.user_id": [2, 2], "b": [1, 3], "rhs.user_id": [2, 2], "c": [3, 3]} ) assert_eq(return_df, expected_df, check_index=False) def test_join_literal(c): return_df = c.sql( """ SELECT lhs.user_id, lhs.b, rhs.user_id, rhs.c FROM user_table_1 AS lhs JOIN user_table_2 AS rhs ON True """ ) expected_df = pd.DataFrame( { "lhs.user_id": [2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], "b": [1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], "rhs.user_id": [1, 1, 2, 4, 1, 1, 2, 4, 1, 1, 2, 4, 1, 1, 2, 4], "c": [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4], } ) assert_eq(return_df, expected_df, check_index=False) return_df = c.sql( """ SELECT lhs.user_id, lhs.b, rhs.user_id, rhs.c FROM user_table_1 AS 
lhs JOIN user_table_2 AS rhs ON False """ ) expected_df = pd.DataFrame({"lhs.user_id": [], "b": [], "rhs.user_id": [], "c": []}) assert_eq(return_df, expected_df, check_dtype=False, check_index=False) def test_conditional_join(c): df1 = pd.DataFrame({"a": [1, 2, 2, 5, 6], "b": ["w", "x", "y", None, "z"]}) df2 = pd.DataFrame({"c": [None, 3, 2, 5], "d": ["h", "i", "j", "k"]}) expected_df = pd.merge(df1, df2, how="inner", left_on=["a"], right_on=["c"]) expected_df = expected_df[~pd.isnull(expected_df.b)] c.create_table("df1", df1) c.create_table("df2", df2) actual_df = c.sql( """ SELECT * FROM df1 INNER JOIN df2 ON ( a = c AND b IS NOT NULL ) """ ) assert_eq(actual_df, expected_df, check_index=False, check_dtype=False) def test_join_on_unary_cond_only(c): df1 = pd.DataFrame({"a": [1, 2, 2, 5, 6], "b": ["w", "x", "y", None, "z"]}) df2 = pd.DataFrame({"c": [None, 3, 2, 5], "d": ["h", "i", "j", "k"]}) c.create_table("df1", df1) c.create_table("df2", df2) df1 = df1.assign(common=1) df2 = df2.assign(common=1) expected_df = df1.merge(df2, on="common").drop(columns="common") expected_df = expected_df[~pd.isnull(expected_df.b)] actual_df = c.sql("SELECT * FROM df1 INNER JOIN df2 ON b IS NOT NULL") assert_eq(actual_df, expected_df, check_index=False, check_dtype=False) def test_join_case_projection_subquery(): c = Context() # Tables for query demo = pd.DataFrame({"demo_sku": [], "hd_dep_count": []}) site_page = pd.DataFrame({"site_page_sk": [], "site_char_count": []}) sales = pd.DataFrame( {"sales_hdemo_sk": [], "sales_page_sk": [], "sold_time_sk": []} ) t_dim = pd.DataFrame({"t_time_sk": [], "t_hour": []}) c.create_table("demos", demo, persist=False) c.create_table("site_page", site_page, persist=False) c.create_table("sales", sales, persist=False) c.create_table("t_dim", t_dim, persist=False) c.sql( """ SELECT CASE WHEN pmc > 0.0 THEN CAST (amc AS DOUBLE) / CAST (pmc AS DOUBLE) ELSE -1.0 END AS am_pm_ratio FROM ( SELECT SUM(amc1) AS amc, SUM(pmc1) AS pmc FROM ( SELECT CASE 
WHEN t_hour BETWEEN 7 AND 8 THEN COUNT(1) ELSE 0 END AS amc1, CASE WHEN t_hour BETWEEN 19 AND 20 THEN COUNT(1) ELSE 0 END AS pmc1 FROM sales ws JOIN demos hd ON (hd.demo_sku = ws.sales_hdemo_sk and hd.hd_dep_count = 5) JOIN site_page sp ON (sp.site_page_sk = ws.sales_page_sk and sp.site_char_count BETWEEN 5000 AND 6000) JOIN t_dim td ON (td.t_time_sk = ws.sold_time_sk and td.t_hour IN (7,8,19,20)) GROUP BY t_hour ) cnt_am_pm ) sum_am_pm """ ).compute() def test_conditional_join_with_limit(c): df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) ddf = dd.from_pandas(df, 5) c.create_table("many_partitions", ddf) df = df.assign(common=1) expected_df = df.merge(df, on="common", suffixes=("", "0")).drop(columns="common") expected_df = expected_df[expected_df["a"] >= 2][:4] # Columns are renamed to use their fully qualified names which is more accurate expected_df = expected_df.rename( columns={"a": "df1.a", "b": "df1.b", "a0": "df2.a", "b0": "df2.b"} ) actual_df = c.sql( """ SELECT * FROM many_partitions as df1, many_partitions as df2 WHERE df1."a" >= 2 LIMIT 4 """ ) assert_eq(actual_df, expected_df, check_index=False) @pytest.mark.filterwarnings( "ignore:You are merging on int and float:UserWarning:dask.dataframe.multi" ) def test_intersect(c): # Join df_simple against itself actual_df = c.sql( """ select count(*) from ( select * from df_simple intersect select * from df_simple ) hot_item limit 100 """ ) assert actual_df["COUNT(*)"].compute()[0] == 3 # Join df_simple against itself, and then that result against df_wide. 
Nothing should match so therefore result should be 0 actual_df = c.sql( """ select count(*) from ( select a, b from df_simple intersect select a, b from df_simple intersect select a, b from df_wide ) hot_item limit 100 """ ) assert len(actual_df["COUNT(*)"]) == 0 actual_df = c.sql( """ select * from df_simple intersect select * from df_simple """ ) assert actual_df.shape[0].compute() == 3 def test_intersect_multi_col(c): df1 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) df2 = pd.DataFrame({"a": [1, 1, 1], "b": [4, 5, 6], "c": [7, 7, 7]}) c.create_table("df1", df1) c.create_table("df2", df2) return_df = c.sql("select * from df1 intersect select * from df2") expected_df = pd.DataFrame( { "df1.a": [1], "df1.b": [4], "df1.c": [7], "df2.a": [1], "df2.b": [4], "df2.c": [7], } ) assert_eq(return_df, expected_df, check_index=False) # TODO: remove this marker once fix for dask-expr#1018 is released # see: https://github.com/dask/dask-expr/issues/1018 @skipif_dask_expr_enabled("Waiting for fix to dask-expr#1018") def test_join_alias_w_projection(c, parquet_ddf): result_df = c.sql( "SELECT t2.c as c_y from parquet_ddf t1, parquet_ddf t2 WHERE t1.a=t2.a and t1.c='A'" ) expected_df = parquet_ddf.merge(parquet_ddf, on=["a"], how="inner") expected_df = expected_df[expected_df["c_x"] == "A"][["c_y"]] assert_eq(result_df, expected_df, check_index=False) def test_filter_columns_post_join(c): df = pd.DataFrame({"a": [1, 2, 3, 4, 5], "c": [1, None, 2, 2, 2]}) df2 = pd.DataFrame({"b": [1, 1, 2, 2, 3], "c": [2, 2, 2, 2, 2]}) c.create_table("df", df) c.create_table("df2", df2) query = "SELECT SUM(df.a) as sum_a, df2.b FROM df INNER JOIN df2 ON df.c=df2.c GROUP BY df2.b" explain_string = c.explain(query) assert ("Projection: df.a, df2.b" in explain_string) or ( "Projection: df2.b, df.a" in explain_string ) result_df = c.sql(query) expected_df = pd.DataFrame({"sum_a": [24, 24, 12], "b": [1, 2, 3]}) assert_eq(result_df, expected_df) def test_join_reorder(c): df = 
pd.DataFrame({"a1": [1, 2, 3, 4, 5] * 2, "a2": [1, 1, 2, 2, 2] * 2}) df2 = pd.DataFrame({"b1": [1, 1, 2, 2, 3] * 10000, "b2": [2, 2, 2, 2, 2] * 10000}) df3 = pd.DataFrame({"c2": [1, 1, 2, 2, 3], "c3": [2, 3, 4, 5, 6]}) c.create_table("a", df, statistics=Statistics(10)) c.create_table("b", df2, statistics=Statistics(50000)) c.create_table("c", df3, statistics=Statistics(5)) # Basic join reorder test query = """ SELECT a1, b2, c3 FROM a, b, c WHERE b1 < 3 AND c3 < 5 AND a1 = b1 AND b2 = c2 """ explain_string = c.explain(query) first_join = "Inner Join: b.b2 = c.c2" second_join = "Inner Join: b.b1 = a.a1" """ LogicalPlan is expected to look something like: Limit: skip=0, fetch=10 Projection: a.a1, b.b2, c.c3 Inner Join: b.b1 = a.a1 Projection: b.b1, b.b2, c.c3 Inner Join: b.b2 = c.c2 Projection: b.b1, b.b2 TableScan: b projection=[b1, b2], full_filters=[b.b1 < Int64(3), b.b2 IS NOT NULL, b.b1 IS NOT NULL] Projection: c.c2, c.c3 TableScan: c projection=[c2, c3], full_filters=[c.c3 < Int64(5), c.c2 IS NOT NULL] Projection: a.a1 TableScan: a projection=[a1], full_filters=[a.a1 < Int64(3), a.a1 IS NOT NULL] So the a-b join is expected to appear earlier in the string than the b-c join """ assert first_join in explain_string and second_join in explain_string assert explain_string.index(second_join) < explain_string.index(first_join) result_df = c.sql(query) merged_df = df.merge(df2, left_on="a1", right_on="b1").merge( df3, left_on="b2", right_on="c2" ) expected_df = merged_df[(merged_df["b1"] < 3) & (merged_df["c3"] < 5)][ ["a1", "b2", "c3"] ] assert_eq(result_df, expected_df, check_index=False) # By default, join reordering should NOT reorder unfiltered dimension tables query = """ SELECT a1, b2, c3 FROM a, b, c WHERE a1 = b1 AND b2 = c2 """ explain_string = c.explain(query) first_join = "Inner Join: b.b1 = a.a1" second_join = "Inner Join: b.b2 = c.c2" assert first_join in explain_string and second_join in explain_string assert explain_string.index(second_join) < 
explain_string.index(first_join) result_df = c.sql(query) expected_df = df.merge(df2, left_on="a1", right_on="b1").merge( df3, left_on="b2", right_on="c2" )[["a1", "b2", "c3"]] assert_eq(result_df, expected_df, check_index=False) def check_broadcast_join(df, val, raises=False): """ Check that the broadcast join is correctly set in the Dask layer or expression graph Parameters ---------- df : DataFrame The DataFrame to check val : bool or float The expected value of the broadcast join raises : bool, optional Whether the legacy Dask check should raise an error if the broadcast join is not set """ if dd._dask_expr_enabled(): from dask_expr._merge import Merge merge_ops = [op for op in df.expr.find_operations(Merge)] assert len(merge_ops) == 1 assert merge_ops[0].broadcast == val else: with pytest.raises(KeyError) if raises else nullcontext(): assert hlg_layer(df.dask, "bcast-join") @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_broadcast_join(c, client, gpu): df1 = dd.from_pandas( pd.DataFrame({"user_id": [1, 2, 3, 4], "b": [5, 6, 7, 8]}), npartitions=2, ) df2 = dd.from_pandas( pd.DataFrame({"user_id": [1, 2, 3, 4] * 4, "c": [5, 6, 7, 8] * 4}), npartitions=8, ) c.create_table("df1", df1, gpu=gpu) c.create_table("df2", df2, gpu=gpu) query_string = """ SELECT df1.user_id as user_id, b, c FROM df1, df2 WHERE df1.user_id = df2.user_id """ expected_df = df1.merge(df2, on="user_id", how="inner") res_df = c.sql(query_string, config_options={"sql.join.broadcast": True}) check_broadcast_join(res_df, True) assert_eq( res_df, expected_df, check_divisions=False, check_index=False, scheduler="distributed", ) res_df = c.sql(query_string, config_options={"sql.join.broadcast": 1.0}) check_broadcast_join(res_df, 1.0) assert_eq( res_df, expected_df, check_divisions=False, check_index=False, scheduler="distributed", ) res_df = c.sql(query_string, config_options={"sql.join.broadcast": 0.5}) check_broadcast_join(res_df, 0.5, raises=True) 
def check_trained_model(c, model_name="my_model", df_name="timeseries"):
    """
    Run a PREDICT query against a registered model and sanity-check the output.

    Parameters
    ----------
    c : Context
        The dask-sql context holding the model and the input table
    model_name : str, optional
        Name of the registered model to predict with
    df_name : str, optional
        Name of the table providing the ``x`` and ``y`` input columns

    Raises
    ------
    AssertionError
        If the prediction registers additional tables in the "root" schema,
        or if the result has no (non-empty) ``target`` column
    """
    sql = f"""
    SELECT * FROM PREDICT(
        MODEL {model_name},
        SELECT x, y FROM {df_name}
    )
    """

    # Take a snapshot of the table names. ``.keys()`` alone is a live view of
    # the mapping, so comparing it with a later ``.keys()`` of the same mapping
    # would always succeed and the check below could never fail.
    tables_before = set(c.schema["root"].tables.keys())
    result_df = c.sql(sql).compute()

    # assert that there are no additional tables in context from prediction
    assert tables_before == set(c.schema["root"].tables.keys())
    assert "target" in result_df.columns
    assert len(result_df["target"]) > 0
test_training_and_prediction(c, gpu_client): gpu = "CUDA" in str(gpu_client.cluster) timeseries = "gpu_timeseries" if gpu else "timeseries" # cuML does not have a GradientBoostingClassifier if not gpu: c.sql( """ CREATE MODEL my_model WITH ( model_class = 'GradientBoostingClassifier', wrap_predict = True, target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) """ ) check_trained_model(c) c.sql( f""" CREATE OR REPLACE MODEL my_model WITH ( model_class = 'LogisticRegression', wrap_predict = True, wrap_fit = False, target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM {timeseries} ) """ ) check_trained_model(c, df_name=timeseries) c.sql( f""" CREATE OR REPLACE MODEL my_model WITH ( model_class = 'LinearRegression', target_column = 'target' ) AS ( SELECT x, y, x*y AS target FROM {timeseries} ) """ ) check_trained_model(c, df_name=timeseries) # TODO: investigate deadlocks on GPU @pytest.mark.xfail( sys.platform in ("darwin", "win32"), reason="Intermittent failures on macOS/Windows", strict=False, ) @pytest.mark.parametrize( "gpu_client", [ False, pytest.param( True, marks=(pytest.mark.gpu, pytest.mark.skip(reason="Deadlocks on GPU")) ), ], indirect=True, ) def test_xgboost_training_prediction(c, gpu_client): gpu = "CUDA" in str(gpu_client.cluster) timeseries = "gpu_timeseries" if gpu else "timeseries" # TODO: XGBClassifiers error on GPU if not gpu: c.sql( """ CREATE OR REPLACE MODEL my_model WITH ( model_class = 'DaskXGBClassifier', target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) """ ) check_trained_model(c) c.sql( """ CREATE OR REPLACE MODEL my_model WITH ( model_class = 'XGBClassifier', target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) """ ) check_trained_model(c) # For GPU tests, set tree_method = 'gpu_hist' tree_method = "gpu_hist" if gpu else "hist" c.sql( f""" CREATE OR REPLACE MODEL my_model WITH ( model_class = 'DaskXGBRegressor', 
@pytest.mark.parametrize(
    "gpu_client", [False, pytest.param(True, marks=pytest.mark.gpu)], indirect=True
)
def test_clustering_and_prediction(c, gpu_client):
    """Train a KMeans model through SQL and check that it can predict."""
    # The GPU variant of the client exposes "CUDA" in its cluster repr; pick
    # the matching input table.
    if "CUDA" in str(gpu_client.cluster):
        timeseries = "gpu_timeseries"
    else:
        timeseries = "timeseries"

    create_model = f"""
        CREATE MODEL my_model WITH (
            model_class = 'KMeans'
        ) AS (
            SELECT x, y
            FROM {timeseries}
            LIMIT 100
        )
    """
    c.sql(create_model)

    check_trained_model(c, df_name=timeseries)
def test_correct_argument_passing(c):
    """Check that ``fit_kwargs`` literals reach the model's ``fit`` as Python values."""
    create_model = """
        CREATE MODEL my_model WITH (
            model_class = 'mock.MagicMock',
            target_column = 'target',
            fit_kwargs = (
                single_quoted_string = 'hello',
                double_quoted_string = "hi",
                integer = -300,
                float = 23.45,
                boolean = False,
                array = ARRAY [ 1, 2 ],
                dict = MAP [ 'a', 1 ],
                set = MULTISET [ 1, 1, 2, 3 ]
            )
        ) AS (
            SELECT x, y, x*y > 0 AS target
            FROM timeseries
            LIMIT 100
        )
    """
    c.sql(create_model)

    registered = c.schema[c.schema_name].models["my_model"]
    mocked_model, columns = registered
    assert list(columns) == ["x", "y"]

    # The mock records the call, so the keyword arguments can be inspected.
    mocked_model.fit.assert_called_once()
    expected_kwargs = {
        "single_quoted_string": "hello",
        "double_quoted_string": "hi",
        "integer": -300,
        "float": 23.45,
        "boolean": False,
        "array": [1, 2],
        "dict": {"a": 1},
        "set": {1, 2, 3},
    }
    assert mocked_model.fit.call_args.kwargs == expected_kwargs
def test_drop_model(c):
    """Exercise DROP MODEL with and without IF EXISTS."""
    drop_strict = "DROP MODEL my_model"
    drop_lenient = "DROP MODEL IF EXISTS my_model"
    create_model = """
        CREATE MODEL IF NOT EXISTS my_model WITH (
            model_class = 'mock.MagicMock',
            target_column = 'target'
        ) AS (
            SELECT x, y, x*y > 0 AS target
            FROM timeseries
            LIMIT 100
        )
    """

    # Nothing registered yet: the strict form fails, the lenient form is a no-op.
    with pytest.raises(RuntimeError):
        c.sql(drop_strict)
    c.sql(drop_lenient)

    # After creating the model, the lenient form must really remove it.
    c.sql(create_model)
    c.sql(drop_lenient)

    registered_models = c.schema[c.schema_name].models
    assert "my_model" not in registered_models
def test_export_model(c, tmpdir):
    """
    Export a trained model as pickle and joblib and check the error paths.

    Covers: exporting an unknown model (RuntimeError), the pickle and joblib
    round trips, and an unsupported format (NotImplementedError).
    """
    with pytest.raises(RuntimeError):
        c.sql(
            """EXPORT MODEL not_available_model with (
                format ='pickle',
                location = '/tmp/model.pkl'
            )"""
        )

    c.sql(
        """
        CREATE MODEL IF NOT EXISTS my_model WITH (
            model_class = 'GradientBoostingClassifier',
            target_column = 'target'
        ) AS (
            SELECT x, y, x*y > 0 AS target
            FROM timeseries
            LIMIT 100
        )
    """
    )

    # Happy flow
    temporary_file = os.path.join(tmpdir, "pickle_model.pkl")
    c.sql(
        f"""EXPORT MODEL my_model with (
            format ='pickle',
            location = '{temporary_file}'
        )"""
    )

    # Close the handle deterministically instead of leaking it to the GC.
    with open(str(temporary_file), "rb") as pickled:
        unpickled_model = pickle.load(pickled)
    assert (
        unpickled_model.estimator.__class__.__name__ == "GradientBoostingClassifier"
    )

    temporary_file = os.path.join(tmpdir, "model.joblib")
    c.sql(
        f"""EXPORT MODEL my_model with (
            format ='joblib',
            location = '{temporary_file}'
        )"""
    )

    assert (
        joblib.load(str(temporary_file)).estimator.__class__.__name__
        == "GradientBoostingClassifier"
    )

    # Build the path outside the ``raises`` block so only the export itself
    # can satisfy the expected exception.
    temporary_dir = os.path.join(tmpdir, "model.onnx")
    with pytest.raises(NotImplementedError):
        c.sql(
            f"""EXPORT MODEL my_model with (
                format ='onnx',
                location = '{temporary_dir}'
            )"""
        )
"GradientBoostingClassifier" ) # test for non sklearn compatible model c.sql( """ CREATE MODEL IF NOT EXISTS non_sklearn_model WITH ( model_class = 'mock.MagicMock', target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) """ ) temporary_dir = os.path.join(tmpdir, "non_sklearn") with pytest.raises(NotImplementedError): c.sql( """EXPORT MODEL non_sklearn_model with ( format ='mlflow', location = '{}' )""".format( temporary_dir ) ) @pytest.mark.xfail( sys.platform == "darwin", reason="Intermittent socket errors on macOS", strict=False ) def test_mlflow_export_xgboost(c, client, tmpdir): # Test only when mlflow & xgboost was installed mlflow = pytest.importorskip("mlflow", reason="mlflow not installed") xgboost = pytest.importorskip("xgboost", reason="xgboost not installed") c.sql( """ CREATE MODEL IF NOT EXISTS my_model_xgboost WITH ( model_class = 'DaskXGBClassifier', target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) """ ) temporary_dir = os.path.join(tmpdir, "mlflow_xgboost") c.sql( """EXPORT MODEL my_model_xgboost with ( format = 'mlflow', location = '{}' )""".format( temporary_dir ) ) assert ( mlflow.sklearn.load_model(str(temporary_dir)).__class__.__name__ == "DaskXGBClassifier" ) def test_mlflow_export_lightgbm(c, tmpdir): # Test only when mlflow & lightgbm was installed mlflow = pytest.importorskip("mlflow", reason="mlflow not installed") lightgbm = pytest.importorskip("lightgbm", reason="lightgbm not installed") c.sql( """ CREATE MODEL IF NOT EXISTS my_model_lightgbm WITH ( model_class = 'LGBMClassifier', target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) """ ) temporary_dir = os.path.join(tmpdir, "mlflow_lightgbm") c.sql( """EXPORT MODEL my_model_lightgbm with ( format = 'mlflow', location = '{}' )""".format( temporary_dir ) ) assert ( mlflow.sklearn.load_model(str(temporary_dir)).__class__.__name__ == "LGBMClassifier" ) def 
test_ml_experiment(c, client): with pytest.raises( ValueError, match="Parameters must include a 'model_class' " "or 'automl_class' parameter.", ): c.sql( """ CREATE EXPERIMENT my_exp WITH ( experiment_class = 'GridSearchCV', tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001], max_depth = ARRAY [3,4,5,10]), target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) """ ) with pytest.raises( ValueError, match="Parameters must include a 'experiment_class' " "parameter for tuning GradientBoostingClassifier.", ): c.sql( """ CREATE EXPERIMENT my_exp WITH ( model_class = 'GradientBoostingClassifier', tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001], max_depth = ARRAY [3,4,5,10]), target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) """ ) with pytest.raises( ValueError, match="Can not import model that.is.not.a.python.class. Make sure you spelled " "it correctly and have installed all packages.", ): c.sql( """ CREATE EXPERIMENT IF NOT EXISTS my_exp WITH ( model_class = 'that.is.not.a.python.class', experiment_class = 'GridSearchCV', tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001], max_depth = ARRAY [3,4,5,10]), target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) """ ) with pytest.raises( ValueError, match="Can not import tuner that.is.not.a.python.class. 
Make sure you spelled " "it correctly and have installed all packages.", ): c.sql( """ CREATE EXPERIMENT IF NOT EXISTS my_exp WITH ( model_class = 'GradientBoostingClassifier', experiment_class = 'that.is.not.a.python.class', tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001], max_depth = ARRAY [3,4,5,10]), target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) """ ) with pytest.raises( ValueError, match="Can not import automl model that.is.not.a.python.class. " "Make sure you spelled " "it correctly and have installed all packages.", ): c.sql( """ CREATE EXPERIMENT my_exp64 WITH ( automl_class = 'that.is.not.a.python.class', automl_kwargs = ( population_size = 2, generations = 2, cv = 2, n_jobs = -1, use_dask = True, max_eval_time_mins = 1 ), target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) """ ) # happy flow c.sql( """ CREATE EXPERIMENT my_exp WITH ( model_class = 'GradientBoostingClassifier', experiment_class = 'GridSearchCV', tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001], max_depth = ARRAY [3,4,5,10]), experiment_kwargs = (n_jobs = -1), target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) """ ) assert "my_exp" in c.schema[c.schema_name].models, "Best model was not registered" check_trained_model(c, "my_exp") with pytest.raises(RuntimeError): # my_exp already exists c.sql( """ CREATE EXPERIMENT my_exp WITH ( model_class = 'GradientBoostingClassifier', experiment_class = 'GridSearchCV', tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001], max_depth = ARRAY [3,4,5,10]), target_column = 'target' ) AS ( SELECT x, y, x*y > 0 AS target FROM timeseries LIMIT 100 ) """ ) c.sql( """ CREATE EXPERIMENT IF NOT EXISTS my_exp WITH ( model_class = 'GradientBoostingClassifier', experiment_class = 'GridSearchCV', tune_parameters = 
@pytest.mark.xfail(
    reason="tpot is broken with sklearn==1.4.0", condition=SKLEARN_EQ_140
)
def test_experiment_automl_classifier(c, client):
    """Run a TPOT classifier experiment and check that the best model is registered."""
    # Only used as a guard: skip the test when tpot is unavailable.
    pytest.importorskip("tpot", reason="tpot not installed")

    create_experiment = """
        CREATE EXPERIMENT my_automl_exp1 WITH (
            automl_class = 'tpot.TPOTClassifier',
            automl_kwargs = (population_size=2, generations=2, cv=2, n_jobs=-1),
            target_column = 'target'
        ) AS (
            SELECT x, y, x*y > 0 AS target
            FROM timeseries
            LIMIT 100
        )
    """
    c.sql(create_experiment)

    registered_models = c.schema[c.schema_name].models
    assert "my_automl_exp1" in registered_models, "Best model was not registered"

    check_trained_model(c, "my_automl_exp1")
def test_predict_with_limit_offset(c):
    """Check that PREDICT over a LIMIT/OFFSET subquery can be computed."""
    create_model = """
        CREATE MODEL my_model WITH (
            model_class = 'GradientBoostingClassifier',
            wrap_predict = True,
            target_column = 'target'
        ) AS (
            SELECT x, y, x*y > 0 AS target
            FROM timeseries
            LIMIT 100
        )
    """
    predict = """
        SELECT * FROM PREDICT (
            MODEL my_model,
            SELECT x, y FROM timeseries LIMIT 100 OFFSET 100
        )
    """

    c.sql(create_model)

    # Only successful execution is checked; the result values are not inspected.
    c.sql(predict).compute()
def test_over_with_partitioning(c, user_table_2):
    """ROW_NUMBER() partitioned by ``c`` is expected to be 1 for every row."""
    query = """
    SELECT
        user_id,
        c,
        ROW_NUMBER() OVER (PARTITION BY c) AS "R"
    FROM user_table_2
    ORDER BY user_id, c
    """
    return_df = c.sql(query)

    row_numbers = [1] * 4
    expected_df = user_table_2.sort_values(["user_id", "c"]).assign(R=row_numbers)

    assert_eq(return_df, expected_df, check_dtype=False, check_index=False)
@pytest.mark.xfail(
    reason="Need to add single_value window function, see https://github.com/dask-contrib/dask-sql/issues/651"
)
def test_over_single_value(c, user_table_1):
    """
    Check SINGLE_VALUE as a window function (expected to fail until implemented).

    The select list must not end with a comma before FROM: otherwise the
    query is rejected as malformed SQL and the xfail no longer tracks the
    missing SINGLE_VALUE support named in the reason.
    """
    return_df = c.sql(
        """
    SELECT
        user_id,
        b,
        SINGLE_VALUE(user_id*10 - b) OVER (PARTITION BY user_id ORDER BY b) AS "O3"
    FROM user_table_1
    """
    )
    expected_df = pd.DataFrame(
        {
            "user_id": user_table_1.user_id,
            "b": user_table_1.b,
            "O3": [19, 7, 19, 27],
        }
    )

    assert_eq(return_df, expected_df, check_dtype=False, check_index=False)
@pytest.fixture(scope="session")
def engine():
    """
    Start a throw-away postgres container and yield a SQLAlchemy engine for it.

    A dedicated "dask-sql" bridge network is created and the container is
    attached to it. The container is killed and the network removed when the
    session ends, and also when setup fails part-way.
    """
    client = docker.from_env()

    network = client.networks.create("dask-sql", driver="bridge")
    postgres = None

    try:
        postgres = client.containers.run(
            "postgres:latest",
            detach=True,
            remove=True,
            network="dask-sql",
            environment={"POSTGRES_HOST_AUTH_METHOD": "trust"},
        )

        # Wait for it to start: the readiness message has to show up twice.
        # NOTE(review): presumably the image restarts the server once during
        # initialisation -- confirm against the postgres image docs.
        start_counter = 2
        postgres.exec_run(["bash"])
        for log_line in postgres.logs(stream=True):
            if b"database system is ready to accept connections" in log_line:
                start_counter -= 1

            if start_counter == 0:
                break

        # get the address and create the connection
        postgres.reload()
        address = postgres.attrs["NetworkSettings"]["Networks"]["dask-sql"]["IPAddress"]
        port = 5432

        db_engine = sqlalchemy.create_engine(
            f"postgresql+psycopg2://postgres@{address}:{port}/postgres"
        )
        yield db_engine
    finally:
        # Single cleanup path for success, setup errors and interrupts, so
        # neither the container nor the network is left behind.
        if postgres is not None:
            postgres.kill()
        network.remove()
def test_filter(assert_query_gives_same_result):
    """Compare WHERE-clause filtering between dask-sql and the reference database."""
    queries = (
        """
        SELECT
            a
        FROM df1
        WHERE
            user_id = 3 AND a > 0.5
    """,
        """
        SELECT
            d
        FROM df2
        WHERE
            d NOT LIKE '%%c'
    """,
        """
        SELECT
            d
        FROM df2
        WHERE
            (d NOT LIKE '%%c') IS NULL
    """,
    )

    # Run in the listed order; the first mismatch fails the test.
    for query in queries:
        assert_query_gives_same_result(query)
def test_case(c, df):
    """
    Check searched and simple CASE expressions against pandas equivalents.

    Each "S<n>" column of the query is rebuilt with ``Series.apply`` on
    ``df.a`` and the frames are compared with ``assert_eq``.
    """
    result_df = c.sql(
        """
    SELECT
        (CASE WHEN a = 3 THEN 1 END) AS "S1",
        (CASE WHEN a > 0 THEN a ELSE 1 END) AS "S2",
        (CASE WHEN a = 4 THEN 3 ELSE a + 1 END) AS "S3",
        (CASE WHEN a = 3 THEN 1 WHEN a > 0 THEN 2 ELSE a END) AS "S4",
        CASE
            WHEN (a >= 1 AND a < 2) OR (a > 2) THEN CAST('in-between' AS VARCHAR) ELSE CAST('out-of-range' AS VARCHAR)
        END AS "S5",
        CASE
            WHEN (a < 2) OR (3 < a AND a < 4) THEN 42 ELSE 47
        END AS "S6",
        CASE WHEN (1 < a AND a <= 4) THEN 1 ELSE 0 END AS "S7",
        CASE a WHEN 2 THEN 5 ELSE a + 1 END AS "S8"
    FROM df
    """
    )
    expected_df = pd.DataFrame(index=df.index)
    # ``np.nan`` rather than ``np.NaN``: the capitalised alias was removed in
    # NumPy 2.0, while ``np.nan`` works on every NumPy version.
    expected_df["S1"] = df.a.apply(lambda a: 1 if a == 3 else np.nan)
    expected_df["S2"] = df.a.apply(lambda a: a if a > 0 else 1)
    expected_df["S3"] = df.a.apply(lambda a: 3 if a == 4 else a + 1).astype("Int64")
    expected_df["S4"] = df.a.apply(lambda a: 1 if a == 3 else 2 if a > 0 else a).astype(
        "Int64"
    )
    expected_df["S5"] = df.a.apply(
        lambda a: "in-between" if ((1 <= a < 2) or (a > 2)) else "out-of-range"
    )
    expected_df["S6"] = df.a.apply(lambda a: 42 if ((a < 2) or (3 < a < 4)) else 47)
    expected_df["S7"] = df.a.apply(lambda a: 1 if (1 < a <= 4) else 0)
    expected_df["S8"] = df.a.apply(lambda a: 5 if a == 2 else a + 1).astype("Int64")

    assert_eq(result_df, expected_df)
expected_df["I"].astype("Int64") assert_eq(df, expected_df) def test_random(c): query_with_seed = """ SELECT RAND(0) AS "0", RAND_INTEGER(0, 10) AS "1" """ result_df = c.sql(query_with_seed) # assert that repeated queries give the same result assert_eq(result_df, c.sql(query_with_seed)) # assert output result_df = result_df.compute() assert result_df["0"].dtype == "float64" assert result_df["1"].dtype == "Int64" assert 0 <= result_df["0"][0] < 1 assert 0 <= result_df["1"][0] < 10 query_wo_seed = """ SELECT RAND() AS "0", RANDOM() AS "1", RAND_INTEGER(30) AS "2" """ result_df = c.sql(query_wo_seed) result_df = result_df.compute() # assert output types assert result_df["0"].dtype == "float64" assert result_df["1"].dtype == "float64" assert result_df["2"].dtype == "Int64" assert 0 <= result_df["0"][0] < 1 assert 0 <= result_df["1"][0] < 1 assert 0 <= result_df["2"][0] < 30 @pytest.mark.parametrize( "input_table", [ "string_table", pytest.param("gpu_string_table", marks=pytest.mark.gpu), ], ) def test_not(c, input_table, request): string_table = request.getfixturevalue(input_table) df = c.sql( f""" SELECT * FROM {input_table} WHERE NOT a LIKE '%normal%' """ ) expected_df = string_table[~string_table.a.str.contains("normal")] assert_eq(df, expected_df) def test_operators(c, df): result_df = c.sql( """ SELECT a * b AS m, -a AS u, a / b AS q, a + b AS s, a - b AS d, a = b AS e, a > b AS g, a >= b AS ge, a < b AS l, a <= b AS le, a <> b AS n FROM df """ ) expected_df = pd.DataFrame(index=df.index) expected_df["m"] = df["a"] * df["b"] expected_df["u"] = -df["a"] expected_df["q"] = df["a"] / df["b"] expected_df["s"] = df["a"] + df["b"] expected_df["d"] = df["a"] - df["b"] expected_df["e"] = df["a"] == df["b"] expected_df["g"] = df["a"] > df["b"] expected_df["ge"] = df["a"] >= df["b"] expected_df["l"] = df["a"] < df["b"] expected_df["le"] = df["a"] <= df["b"] expected_df["n"] = df["a"] != df["b"] assert_eq(result_df, expected_df) @pytest.mark.parametrize( "input_table,gpu", [ 
("string_table", False), pytest.param( "gpu_string_table", True, marks=( pytest.mark.gpu, pytest.mark.xfail( reason="Failing due to cuDF bug https://github.com/rapidsai/cudf/issues/9434" ), ), ), ], ) def test_like(c, input_table, gpu, request): string_table = request.getfixturevalue(input_table) df = c.sql( f""" SELECT * FROM {input_table} WHERE a SIMILAR TO '%n[a-z]rmal st_i%' """ ) assert_eq(df, string_table.iloc[[0, 3]]) df = c.sql( f""" SELECT * FROM {input_table} WHERE a NOT SIMILAR TO '%n[a-z]rmal st_i%' """ ) assert_eq(df, string_table.iloc[[1, 2]]) df = c.sql( f""" SELECT * FROM {input_table} WHERE a LIKE '%n[a-z]rmal st_i%' """ ) assert len(df) == 0 df = c.sql( f""" SELECT * FROM {input_table} WHERE a NOT LIKE '%n[a-z]rmal st_i%' """ ) assert_eq(df, string_table) df = c.sql( f""" SELECT * FROM {input_table} WHERE a LIKE '%a Normal String%' """ ) assert len(df) == 0 df = c.sql( f""" SELECT * FROM {input_table} WHERE a ILIKE '%a Normal String%' """ ) assert_eq(df, string_table.iloc[[0, 3]]) df = c.sql( f""" SELECT * FROM {input_table} WHERE a NOT ILIKE '%a Normal String%' """ ) assert_eq(df, string_table.iloc[[1, 2]]) # TODO: uncomment when sqlparser adds parsing support for non-standard escape characters # https://github.com/dask-contrib/dask-sql/issues/754 # df = c.sql( # f""" # SELECT * FROM {input_table} # WHERE a LIKE 'Ä%Ä_Ä%' ESCAPE 'Ä' # """ # ) # assert_eq(df, string_table.iloc[[1]]) df = c.sql( f""" SELECT * FROM {input_table} WHERE a SIMILAR TO '^|()-*r[r]$' ESCAPE 'r' """ ) assert_eq(df, string_table.iloc[[2, 3]]) df = c.sql( f""" SELECT * FROM {input_table} WHERE a LIKE '^|()-*r[r]$' ESCAPE 'r' """ ) assert_eq(df, string_table.iloc[[2]]) df = c.sql( f""" SELECT * FROM {input_table} WHERE a LIKE '%_' ESCAPE 'r' """ ) assert_eq(df, string_table) string_table2 = pd.DataFrame({"b": ["a", "b", None, pd.NA, float("nan")]}) c.create_table("string_table2", string_table2, gpu=gpu) df = c.sql( """ SELECT * FROM string_table2 WHERE b LIKE 'b' """ ) 
assert_eq(df, string_table2.iloc[[1]]) def test_null(c): df = c.sql( """ SELECT c IS NOT NULL AS nn, c IS NULL AS n FROM user_table_nan """ ) expected_df = pd.DataFrame(index=[0, 1, 2]) expected_df["nn"] = [True, False, True] expected_df["nn"] = expected_df["nn"].astype("boolean") expected_df["n"] = [False, True, False] assert_eq(df, expected_df) df = c.sql( """ SELECT a IS NOT NULL AS nn, a IS NULL AS n FROM string_table """ ) expected_df = pd.DataFrame(index=[0, 1, 2, 3]) expected_df["nn"] = [True, True, True, True] expected_df["nn"] = expected_df["nn"].astype("boolean") expected_df["n"] = [False, False, False, False] assert_eq(df, expected_df) @pytest.mark.filterwarnings( "ignore:divide by zero:RuntimeWarning:dask_sql.physical.rex.core.call" ) @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_coalesce(c, gpu): df = dd.from_pandas( pd.DataFrame({"a": [1, 2, 3], "b": [np.nan] * 3}), npartitions=1 ) c.create_table("df", df, gpu=gpu) df = c.sql( """ SELECT COALESCE(3, 5) as c1, COALESCE(NULL, NULL) as c2, COALESCE(NULL, 'hi') as c3, COALESCE(NULL, NULL, 'bye', 5/0) as c4, COALESCE(NULL, 3/2, NULL, 'fly') as c5, COALESCE(NULL, MEAN(b), MEAN(a), 4/0) as c6 FROM df """ ) expected_df = pd.DataFrame( { "c1": [3], "c2": [pd.NA], "c3": ["hi"], "c4": ["bye"], "c5": ["1.5"], "c6": [2.0], } ) expected_df["c2"] = expected_df["c2"].astype("Int8") assert_eq(df, expected_df, check_dtype=False) df = c.sql( """ SELECT COALESCE(a, b) as c1, COALESCE(b, a) as c2, COALESCE(a, a) as c3, COALESCE(b, b) as c4 FROM df """ ) expected_df = pd.DataFrame( { "c1": [1, 2, 3], "c2": [1, 2, 3], "c3": [1, 2, 3], "c4": [np.nan] * 3, } ) assert_eq(df, expected_df, check_dtype=False) c.drop_table("df") def test_boolean_operations(c): df = dd.from_pandas(pd.DataFrame({"b": [1, 0, -1]}), npartitions=1) df["b"] = df["b"].apply( lambda x: pd.NA if x < 0 else x > 0, meta=("b", "bool") ) # turn into a bool column c.create_table("df", df) result_df = c.sql( """ 
SELECT b IS TRUE AS t, b IS FALSE AS f, b IS NOT TRUE AS nt, b IS NOT FALSE AS nf, b IS UNKNOWN AS u, b IS NOT UNKNOWN AS nu FROM df""" ) expected_df = pd.DataFrame( { "t": df.b.astype("boolean").fillna(False), "f": ~df.b.astype("boolean").fillna(True), "nt": ~df.b.astype("boolean").fillna(False), "nf": df.b.astype("boolean").fillna(True), "u": df.b.isna(), "nu": ~df.b.isna().astype("boolean"), }, ) assert_eq(result_df, expected_df, check_dtype=False) def test_math_operations(c, df): result_df = c.sql( """ SELECT ABS(b) AS "abs" , ACOS(b) AS "acos" , ASIN(b) AS "asin" , ATAN(b) AS "atan" , ATAN2(a, b) AS "atan2" , CBRT(b) AS "cbrt" , CEIL(b) AS "ceil" , COS(b) AS "cos" , COT(b) AS "cot" , DEGREES(b) AS "degrees" , EXP(b) AS "exp" , FLOOR(b) AS "floor" , LOG10(b) AS "log10" , LN(b) AS "ln" , MOD(b, 4) AS "mod" , POWER(b, 2) AS "power" , POWER(b, a) AS "power2" , RADIANS(b) AS "radians" , ROUND(b) AS "round" , ROUND(b, 3) AS "round2" , SIGN(b) AS "sign" , SIN(b) AS "sin" , TAN(b) AS "tan" , TRUNCATE(b) AS "truncate" FROM df """ ) expected_df = pd.DataFrame(index=df.index) expected_df["abs"] = df.b.abs() expected_df["acos"] = np.arccos(df.b) expected_df["asin"] = np.arcsin(df.b) expected_df["atan"] = np.arctan(df.b) expected_df["atan2"] = np.arctan2(df.a, df.b) expected_df["cbrt"] = np.cbrt(df.b) expected_df["ceil"] = np.ceil(df.b) expected_df["cos"] = np.cos(df.b) expected_df["cot"] = 1 / np.tan(df.b) expected_df["degrees"] = df.b / np.pi * 180 expected_df["exp"] = np.exp(df.b) expected_df["floor"] = np.floor(df.b) expected_df["log10"] = np.log10(df.b) expected_df["ln"] = np.log(df.b) expected_df["mod"] = np.mod(df.b, 4) expected_df["power"] = np.power(df.b, 2) expected_df["power2"] = np.power(df.b, df.a) expected_df["radians"] = df.b / 180 * np.pi expected_df["round"] = np.round(df.b) expected_df["round2"] = np.round(df.b, 3) expected_df["sign"] = np.sign(df.b) expected_df["sin"] = np.sin(df.b) expected_df["tan"] = np.tan(df.b) expected_df["truncate"] = 
np.trunc(df.b) assert_eq(result_df, expected_df) def test_integer_div(c, df_simple): df = c.sql( """ SELECT 1 / a AS a, a / 2 AS b, 1.0 / a AS c FROM df_simple """ ) expected_df = pd.DataFrame( { "a": (1 // df_simple.a).astype("Int64"), "b": (df_simple.a // 2).astype("Int64"), "c": 1 / df_simple.a, } ) assert_eq(df, expected_df) @pytest.mark.xfail(reason="Subquery expressions not yet enabled") def test_subqueries(c, user_table_1, user_table_2): df = c.sql( """ SELECT * FROM user_table_2 WHERE EXISTS( SELECT * FROM user_table_1 WHERE user_table_1.b = user_table_2.c ) """ ) assert_eq(df, user_table_2[user_table_2.c.isin(user_table_1.b)], check_index=False) @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_string_functions(c, gpu): if gpu: input_table = "gpu_string_table" else: input_table = "string_table" df = c.sql( f""" SELECT a || 'hello' || a AS a, CONCAT(a, 'hello', a) as b, CHAR_LENGTH(a) AS c, UPPER(a) AS d, LOWER(a) AS e, -- POSITION('a' IN a FROM 4) AS f, -- POSITION('ZL' IN a) AS g, TRIM('a' FROM a) AS h, TRIM(BOTH 'a' FROM a) AS i, TRIM(LEADING 'a' FROM a) AS j, TRIM(TRAILING 'a' FROM a) AS k, -- OVERLAY(a PLACING 'XXX' FROM -1) AS l, -- OVERLAY(a PLACING 'XXX' FROM 2 FOR 4) AS m, -- OVERLAY(a PLACING 'XXX' FROM 2 FOR 1) AS n, SUBSTRING(a FROM -1) AS o, SUBSTRING(a FROM 10) AS p, SUBSTRING(a FROM 2) AS q, SUBSTRING(a FROM 2 FOR 2) AS r, SUBSTR(a, 3, 6) AS s, INITCAP(a) AS t, INITCAP(UPPER(a)) AS u, INITCAP(LOWER(a)) AS v, REPLACE(a, 'r', 'l') as w, REPLACE('Another String', 'th', 'b') as x FROM {input_table} """ ) if gpu: df = df.astype({"c": "int64"}) # , "f": "int64", "g": "int64"}) expected_df = pd.DataFrame( { "a": ["a normal stringhelloa normal string"], "b": ["a normal stringhelloa normal string"], "c": [15], "d": ["A NORMAL STRING"], "e": ["a normal string"], # "f": [7], # position from syntax not supported # "g": [0], "h": [" normal string"], "i": [" normal string"], "j": [" normal string"], "k": ["a 
normal string"], # "l": ["XXXormal string"], # overlay from syntax not supported by parser # "m": ["aXXXmal string"], # "n": ["aXXXnormal string"], "o": ["a normal string"], "p": ["string"], "q": [" normal string"], "r": [" n"], "s": ["normal"], "t": ["A Normal String"], "u": ["A Normal String"], "v": ["A Normal String"], "w": ["a nolmal stling"], "x": ["Anober String"], } ) assert_eq( df.head(1), expected_df, ) @pytest.mark.xfail(reason="POSITION syntax not supported by parser") @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_string_position(c, gpu): if gpu: input_table = "gpu_string_table" else: input_table = "string_table" df = c.sql( f""" SELECT POSITION('a' IN a FROM 4) AS f, POSITION('ZL' IN a) AS g, FROM {input_table} """ ) if gpu: df = df.astype({"f": "int64", "g": "int64"}) expected_df = pd.DataFrame( { "f": [7], "g": [0], } ) assert_eq( df.head(1), expected_df, ) @pytest.mark.xfail(reason="OVERLAY syntax not supported by parser") @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_string_overlay(c, gpu): if gpu: input_table = "gpu_string_table" else: input_table = "string_table" df = c.sql( f""" SELECT OVERLAY(a PLACING 'XXX' FROM -1) AS l, OVERLAY(a PLACING 'XXX' FROM 2 FOR 4) AS m, OVERLAY(a PLACING 'XXX' FROM 2 FOR 1) AS n, FROM {input_table} """ ) if gpu: df = df.astype({"c": "int64"}) # , "f": "int64", "g": "int64"}) expected_df = pd.DataFrame( { "l": ["XXXormal string"], "m": ["aXXXmal string"], "n": ["aXXXnormal string"], } ) assert_eq( df.head(1), expected_df, ) def test_date_functions(c): date = datetime(2021, 10, 3, 15, 53, 42, 47) df = dd.from_pandas(pd.DataFrame({"d": [date]}), npartitions=1) c.create_table("df", df) df = c.sql( """ SELECT EXTRACT(CENTURY FROM d) AS "century", EXTRACT(DAY FROM d) AS "day", EXTRACT(DECADE FROM d) AS "decade", EXTRACT(DOW FROM d) AS "dow", EXTRACT(DOY FROM d) AS "doy", EXTRACT(HOUR FROM d) AS "hour", EXTRACT(MICROSECONDS FROM 
d) AS "microsecond", EXTRACT(MILLENNIUM FROM d) AS "millennium", EXTRACT(MILLISECONDS FROM d) AS "millisecond", EXTRACT(MINUTE FROM d) AS "minute", EXTRACT(MONTH FROM d) AS "month", EXTRACT(QUARTER FROM d) AS "quarter", EXTRACT(SECOND FROM d) AS "second", EXTRACT(WEEK FROM d) AS "week", EXTRACT(YEAR FROM d) AS "year", EXTRACT(DATE FROM d) AS "date", LAST_DAY(d) as "last_day", TIMESTAMPADD(YEAR, 1, d) as "plus_1_year", TIMESTAMPADD(MONTH, 1, d) as "plus_1_month", TIMESTAMPADD(WEEK, 1, d) as "plus_1_week", TIMESTAMPADD(DAY, 1, d) as "plus_1_day", TIMESTAMPADD(HOUR, 1, d) as "plus_1_hour", TIMESTAMPADD(MINUTE, 1, d) as "plus_1_min", TIMESTAMPADD(SECOND, 1, d) as "plus_1_sec", TIMESTAMPADD(MICROSECOND, 999*1000, d) as "plus_999_millisec", TIMESTAMPADD(MICROSECOND, 999, d) as "plus_999_microsec", TIMESTAMPADD(QUARTER, 1, d) as "plus_1_qt", CEIL(d TO DAY) as ceil_to_day, CEIL(d TO HOUR) as ceil_to_hour, CEIL(d TO MINUTE) as ceil_to_minute, CEIL(d TO SECOND) as ceil_to_seconds, CEIL(d TO MILLISECOND) as ceil_to_millisec, FLOOR(d TO DAY) as floor_to_day, FLOOR(d TO HOUR) as floor_to_hour, FLOOR(d TO MINUTE) as floor_to_minute, FLOOR(d TO SECOND) as floor_to_seconds, FLOOR(d TO MILLISECOND) as floor_to_millisec FROM df """ ) expected_df = pd.DataFrame( { "century": [20], "day": [3], "decade": [202], "dow": [0], "doy": [276], "hour": [15], "microsecond": [47], "millennium": [2], "millisecond": [47000], "minute": [53], "month": [10], "quarter": [4], "second": [42], "week": [39], "year": [2021], "date": [datetime(2021, 10, 3)], "last_day": [datetime(2021, 10, 31, 15, 53, 42, 47)], "plus_1_year": [datetime(2022, 10, 3, 15, 53, 42, 47)], "plus_1_month": [datetime(2021, 11, 3, 15, 53, 42, 47)], "plus_1_week": [datetime(2021, 10, 10, 15, 53, 42, 47)], "plus_1_day": [datetime(2021, 10, 4, 15, 53, 42, 47)], "plus_1_hour": [datetime(2021, 10, 3, 16, 53, 42, 47)], "plus_1_min": [datetime(2021, 10, 3, 15, 54, 42, 47)], "plus_1_sec": [datetime(2021, 10, 3, 15, 53, 43, 47)], 
"plus_999_millisec": [datetime(2021, 10, 3, 15, 53, 42, 1000 * 999 + 47)], "plus_999_microsec": [datetime(2021, 10, 3, 15, 53, 42, 1046)], "plus_1_qt": [datetime(2022, 1, 3, 15, 53, 42, 47)], "ceil_to_day": [datetime(2021, 10, 4)], "ceil_to_hour": [datetime(2021, 10, 3, 16)], "ceil_to_minute": [datetime(2021, 10, 3, 15, 54)], "ceil_to_seconds": [datetime(2021, 10, 3, 15, 53, 43)], "ceil_to_millisec": [datetime(2021, 10, 3, 15, 53, 42, 1000)], "floor_to_day": [datetime(2021, 10, 3)], "floor_to_hour": [datetime(2021, 10, 3, 15)], "floor_to_minute": [datetime(2021, 10, 3, 15, 53)], "floor_to_seconds": [datetime(2021, 10, 3, 15, 53, 42)], "floor_to_millisec": [datetime(2021, 10, 3, 15, 53, 42)], } ) assert_eq(df, expected_df, check_dtype=False) # test exception handling with pytest.raises(NotImplementedError): df = c.sql( """ SELECT FLOOR(d TO YEAR) as floor_to_year FROM df """ ) def test_timestampdiff(c): ts_literal1 = datetime(2002, 3, 7, 9, 10, 5, 123) ts_literal2 = datetime(2001, 6, 5, 10, 11, 6, 234) df = dd.from_pandas( pd.DataFrame({"ts_literal1": [ts_literal1], "ts_literal2": [ts_literal2]}), npartitions=1, ) c.create_table("df", df) query = """ SELECT timestampdiff(NANOSECOND, ts_literal1, ts_literal2) as res0, timestampdiff(MICROSECOND, ts_literal1, ts_literal2) as res1, timestampdiff(SECOND, ts_literal1, ts_literal2) as res2, timestampdiff(MINUTE, ts_literal1, ts_literal2) as res3, timestampdiff(HOUR, ts_literal1, ts_literal2) as res4, timestampdiff(DAY, ts_literal1, ts_literal2) as res5, timestampdiff(WEEK, ts_literal1, ts_literal2) as res6, timestampdiff(MONTH, ts_literal1, ts_literal2) as res7, timestampdiff(QUARTER, ts_literal1, ts_literal2) as res8, timestampdiff(YEAR, ts_literal1, ts_literal2) as res9 FROM df """ df = c.sql(query) expected_df = pd.DataFrame( { "res0": [-23756338999889000], "res1": [-23756338999889], "res2": [-23756338], "res3": [-395938], "res4": [-6598], "res5": [-274], "res6": [-39], "res7": [-9], "res8": [-3], "res9": [0], } ) 
assert_eq(df, expected_df, check_dtype=False) test = pd.DataFrame( { "a": [ datetime(2002, 6, 5, 2, 1, 5, 200), datetime(2002, 9, 1), datetime(1970, 12, 3), ], "b": [ datetime(2002, 6, 7, 1, 0, 2, 100), datetime(2003, 6, 5), datetime(2038, 6, 5), ], } ) c.create_table("test", test) query = ( "SELECT timestampdiff(NANOSECOND, a, b) as nanoseconds," "timestampdiff(MICROSECOND, a, b) as microseconds," "timestampdiff(SECOND, a, b) as seconds," "timestampdiff(MINUTE, a, b) as minutes," "timestampdiff(HOUR, a, b) as hours," "timestampdiff(DAY, a, b) as days," "timestampdiff(WEEK, a, b) as weeks," "timestampdiff(MONTH, a, b) as months," "timestampdiff(QUARTER, a, b) as quarters," "timestampdiff(YEAR, a, b) as years" " FROM test" ) ddf = c.sql(query) expected_df = pd.DataFrame( { "nanoseconds": [ 169136999900000, 23932800000000000, 2130278400000000000, ], "microseconds": [169136999900, 23932800000000, 2130278400000000], "seconds": [169136, 23932800, 2130278400], "minutes": [2818, 398880, 35504640], "hours": [46, 6648, 591744], "days": [1, 277, 24656], "weeks": [0, 39, 3522], "months": [0, 9, 810], "quarters": [0, 3, 270], "years": [0, 0, 67], } ) assert_eq(ddf, expected_df, check_dtype=False) @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_totimestamp(c, gpu): df = pd.DataFrame( { "a": np.array([1203073300, 1406073600, 2806073600]), } ) c.create_table("df", df, gpu=gpu) df = c.sql( """ SELECT to_timestamp(a) AS date FROM df """ ) expected_df = pd.DataFrame( { "date": [ datetime(2008, 2, 15, 11, 1, 40), datetime(2014, 7, 23), datetime(2058, 12, 2, 16, 53, 20), ], } ) assert_eq(df, expected_df, check_dtype=False) df = pd.DataFrame( { "a": np.array(["1997-02-28 10:30:00", "1997-03-28 10:30:01"]), } ) c.create_table("df", df, gpu=gpu) df = c.sql( """ SELECT to_timestamp(a) AS date FROM df """ ) expected_df = pd.DataFrame( { "date": [ datetime(1997, 2, 28, 10, 30, 0), datetime(1997, 3, 28, 10, 30, 1), ], } ) assert_eq(df, 
expected_df, check_dtype=False) df = pd.DataFrame( { "a": np.array(["02/28/1997", "03/28/1997"]), } ) c.create_table("df", df, gpu=gpu) df = c.sql( """ SELECT to_timestamp(a, "%m/%d/%Y") AS date FROM df """ ) expected_df = pd.DataFrame( { "date": [ datetime(1997, 2, 28, 0, 0, 0), datetime(1997, 3, 28, 0, 0, 0), ], } ) # https://github.com/rapidsai/cudf/issues/12062 if not gpu: assert_eq(df, expected_df, check_dtype=False) int_input = 1203073300 df = c.sql(f"SELECT to_timestamp({int_input}) as date") expected_df = pd.DataFrame( { "date": [ datetime(2008, 2, 15, 11, 1, 40), ], } ) assert_eq(df, expected_df, check_dtype=False) string_input = "1997-02-28 10:30:00" df = c.sql(f"SELECT to_timestamp('{string_input}') as date") expected_df = pd.DataFrame( { "date": [ datetime(1997, 2, 28, 10, 30, 0), ], } ) assert_eq(df, expected_df, check_dtype=False) string_input = "02/28/1997" df = c.sql(f"SELECT to_timestamp('{string_input}', '%m/%d/%Y') as date") expected_df = pd.DataFrame( { "date": [ datetime(1997, 2, 28, 0, 0, 0), ], } ) assert_eq(df, expected_df, check_dtype=False) @pytest.mark.parametrize( "gpu", [ False, pytest.param( True, marks=(pytest.mark.gpu,), ), ], ) def test_extract_date(c, gpu): df = pd.DataFrame( { "a": [1, 2, 3], "b": [4, 5, 6], } ) df["t"] = [datetime(2021, 1, 1), datetime(2022, 2, 2), datetime(2023, 3, 3)] c.create_table("df", df, gpu=gpu) result = c.sql("SELECT EXTRACT(DATE FROM t) AS e FROM df") expected_df = pd.DataFrame( {"e": [datetime(2021, 1, 1), datetime(2022, 2, 2), datetime(2023, 3, 3)]} ) assert_eq(result, expected_df) result = c.sql("SELECT * FROM df WHERE EXTRACT(DATE FROM t) > '2021-02-01'") expected_df = pd.DataFrame( { "a": [2, 3], "b": [5, 6], "t": [datetime(2022, 2, 2), datetime(2023, 3, 3)], } ) assert_eq(result, expected_df, check_index=False) result = c.sql( "SELECT * FROM df WHERE EXTRACT(DATE FROM t) BETWEEN '2020-10-01' AND '2022-10-10'" ) expected_df = pd.DataFrame( {"a": [1, 2], "b": [4, 5], "t": [datetime(2021, 1, 1), 
datetime(2022, 2, 2)]} ) assert_eq(result, expected_df) result = c.sql("SELECT TIMESTAMPADD(YEAR, 1, EXTRACT(DATE FROM t)) AS ta FROM df") expected_df = pd.DataFrame( {"ta": [datetime(2022, 1, 1), datetime(2023, 2, 2), datetime(2024, 3, 3)]} ) assert_eq(result, expected_df) result = c.sql("SELECT EXTRACT(DATE FROM t) + INTERVAL '2 days' AS i FROM df") expected_df = pd.DataFrame( {"i": [datetime(2021, 1, 3), datetime(2022, 2, 4), datetime(2023, 3, 5)]} ) assert_eq(result, expected_df) @pytest.mark.parametrize( "gpu", [ False, pytest.param( True, marks=(pytest.mark.gpu,), ), ], ) def test_scalar_timestamps(c, gpu): df = pd.DataFrame({"d": [1203073300, 1503073700]}) c.create_table("df", df, gpu=gpu) expected_df = pd.DataFrame( { "dt": [datetime(2008, 2, 20, 11, 1, 40), datetime(2017, 8, 23, 16, 28, 20)], } ) df1 = c.sql("SELECT to_timestamp(d) + INTERVAL '5 days' AS dt FROM df") assert_eq(df1, expected_df) df2 = c.sql("SELECT CAST(d AS TIMESTAMP) + INTERVAL '5 days' AS dt FROM df") assert_eq(df2, expected_df) df1 = c.sql("SELECT TIMESTAMPADD(DAY, 5, to_timestamp(d)) AS dt FROM df") assert_eq(df1, expected_df) df2 = c.sql("SELECT TIMESTAMPADD(DAY, 5, d) AS dt FROM df") assert_eq(df2, expected_df) df3 = c.sql("SELECT TIMESTAMPADD(DAY, 5, CAST(d AS TIMESTAMP)) AS dt FROM df") assert_eq(df3, expected_df) expected_df = pd.DataFrame({"day": [15, 18]}) df1 = c.sql("SELECT EXTRACT(DAY FROM to_timestamp(d)) AS day FROM df") assert_eq(df1, expected_df, check_dtype=False) df2 = c.sql("SELECT EXTRACT(DAY FROM CAST(d AS TIMESTAMP)) AS day FROM df") assert_eq(df2, expected_df, check_dtype=False) expected_df = pd.DataFrame( { "ceil_to_day": [datetime(2008, 2, 16), datetime(2017, 8, 19)], } ) df1 = c.sql("SELECT CEIL(to_timestamp(d) TO DAY) AS ceil_to_day FROM df") assert_eq(df1, expected_df, check_dtype=(not gpu)) df2 = c.sql("SELECT CEIL(CAST(d AS TIMESTAMP) TO DAY) AS ceil_to_day FROM df") assert_eq(df2, expected_df) expected_df = pd.DataFrame( { "floor_to_day": [datetime(2008, 2, 
15), datetime(2017, 8, 18)], } ) df1 = c.sql("SELECT FLOOR(to_timestamp(d) TO DAY) AS floor_to_day FROM df") assert_eq(df1, expected_df, check_dtype=(not gpu)) df2 = c.sql("SELECT FLOOR(CAST(d AS TIMESTAMP) TO DAY) AS floor_to_day FROM df") assert_eq(df2, expected_df) df = pd.DataFrame({"d1": [1203073300], "d2": [1503073700]}) c.create_table("df", df, gpu=gpu) expected_df = pd.DataFrame({"dt": [3472]}) df1 = c.sql( "SELECT TIMESTAMPDIFF(DAY, to_timestamp(d1), to_timestamp(d2)) AS dt FROM df" ) # TODO: The GPU case returns an incorrect value here if not gpu: assert_eq(df1, expected_df) df2 = c.sql("SELECT TIMESTAMPDIFF(DAY, d1, d2) AS dt FROM df") assert_eq(df2, expected_df, check_dtype=False) df3 = c.sql( "SELECT TIMESTAMPDIFF(DAY, CAST(d1 AS TIMESTAMP), CAST(d2 AS TIMESTAMP)) AS dt FROM df" ) assert_eq(df3, expected_df) scalar1 = 1203073300 scalar2 = 1503073700 expected_df = pd.DataFrame({"dt": [datetime(2008, 2, 20, 11, 1, 40)]}) df1 = c.sql(f"SELECT to_timestamp({scalar1}) + INTERVAL '5 days' AS dt") assert_eq(df1, expected_df) # TODO: Fix seconds/nanoseconds conversion # df2 = c.sql(f"SELECT CAST({scalar1} AS TIMESTAMP) + INTERVAL '5 days' AS dt") # assert_eq(df2, expected_df) df1 = c.sql(f"SELECT TIMESTAMPADD(DAY, 5, to_timestamp({scalar1})) AS dt") assert_eq(df1, expected_df) df2 = c.sql(f"SELECT TIMESTAMPADD(DAY, 5, {scalar1}) AS dt") assert_eq(df2, expected_df) df3 = c.sql(f"SELECT TIMESTAMPADD(DAY, 5, CAST({scalar1} AS TIMESTAMP)) AS dt") assert_eq(df3, expected_df) expected_df = pd.DataFrame({"day": [15]}) df1 = c.sql(f"SELECT EXTRACT(DAY FROM to_timestamp({scalar1})) AS day") assert_eq(df1, expected_df, check_dtype=False) # TODO: Fix seconds/nanoseconds conversion # df2 = c.sql(f"SELECT EXTRACT(DAY FROM CAST({scalar1} AS TIMESTAMP)) AS day") # assert_eq(df2, expected_df, check_dtype=False) expected_df = pd.DataFrame({"ceil_to_day": [datetime(2008, 2, 16)]}) df1 = c.sql(f"SELECT CEIL(to_timestamp({scalar1}) TO DAY) AS ceil_to_day") assert_eq(df1, 
expected_df) df2 = c.sql(f"SELECT CEIL(CAST({scalar1} AS TIMESTAMP) TO DAY) AS ceil_to_day") assert_eq(df2, expected_df) expected_df = pd.DataFrame({"floor_to_day": [datetime(2008, 2, 15)]}) df1 = c.sql(f"SELECT FLOOR(to_timestamp({scalar1}) TO DAY) AS floor_to_day") assert_eq(df1, expected_df) df2 = c.sql(f"SELECT FLOOR(CAST({scalar1} AS TIMESTAMP) TO DAY) AS floor_to_day") assert_eq(df2, expected_df) expected_df = pd.DataFrame({"dt": [3472]}) df1 = c.sql( f"SELECT TIMESTAMPDIFF(DAY, to_timestamp({scalar1}), to_timestamp({scalar2})) AS dt" ) assert_eq(df1, expected_df) df2 = c.sql(f"SELECT TIMESTAMPDIFF(DAY, {scalar1}, {scalar2}) AS dt") assert_eq(df2, expected_df, check_dtype=False) df3 = c.sql( f"SELECT TIMESTAMPDIFF(DAY, CAST({scalar1} AS TIMESTAMP), CAST({scalar2} AS TIMESTAMP)) AS dt" ) assert_eq(df3, expected_df) def test_datetime_coercion(c): d_table = pd.DataFrame( { "d_date": [ datetime(2023, 7, 1), datetime(2023, 7, 5), datetime(2023, 7, 10), datetime(2023, 7, 15), ], "x": [1, 2, 3, 4], } ) c.create_table("d_table", d_table) df = c.sql( """ SELECT * FROM d_table d1, d_table d2 WHERE d2.x < d1.x + (1 + 2) AND d2.d_date > d1.d_date + (2 + 3) """ ) expected_df = c.sql( """ SELECT * FROM d_table d1, d_table d2 WHERE d2.x < d1.x + (1 + 2) AND d2.d_date > d1.d_date + INTERVAL '5 days' """ ) assert_eq(df, expected_df) ================================================ FILE: tests/integration/test_sample.py ================================================ import numpy as np import pytest from tests.utils import assert_eq def get_system_sample(df, fraction, seed): random_state = np.random.RandomState(seed) random_choice = random_state.choice( [True, False], size=df.npartitions, replace=True, p=[fraction, 1 - fraction], ) if random_choice.any(): df = df.partitions[random_choice] else: df = df.head(0, compute=False) return df @pytest.mark.xfail(reason="WIP DataFusion") def test_sample(c, df): ddf = c.sql("SELECT * FROM df") # fixed system samples assert_eq( 
c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (20) REPEATABLE (10)"), get_system_sample(ddf, 0.20, 10), ) assert_eq( c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (20) REPEATABLE (11)"), get_system_sample(ddf, 0.20, 11), ) assert_eq( c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (50) REPEATABLE (10)"), get_system_sample(ddf, 0.50, 10), ) assert_eq( c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (0.001) REPEATABLE (10)"), get_system_sample(ddf, 0.00001, 10), ) assert_eq( c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (99.999) REPEATABLE (10)"), get_system_sample(ddf, 0.99999, 10), ) # fixed bernoulli samples assert_eq( c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (50) REPEATABLE (10)"), ddf.sample(frac=0.50, replace=False, random_state=10), ) assert_eq( c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (70) REPEATABLE (10)"), ddf.sample(frac=0.70, replace=False, random_state=10), ) assert_eq( c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (0.001) REPEATABLE (10)"), ddf.sample(frac=0.00001, replace=False, random_state=10), ) assert_eq( c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (99.999) REPEATABLE (10)"), ddf.sample(frac=0.99999, replace=False, random_state=10), ) # variable samples, can only check boundaries return_df = c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (50)") assert len(return_df) >= 0 and len(return_df) <= len(df) return_df = c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (50)") assert len(return_df) >= 0 and len(return_df) <= len(df) ================================================ FILE: tests/integration/test_schema.py ================================================ import dask.dataframe as dd import numpy as np import pytest from dask_sql.utils import ParsingException from tests.utils import assert_eq @pytest.mark.xfail(reason="WIP DataFusion") def test_table_schema(c, df): original_df = c.sql("SELECT * FROM df") assert_eq(original_df, c.sql("SELECT * FROM root.df")) c.sql("CREATE SCHEMA foo") assert_eq(original_df, c.sql("SELECT * FROM df")) c.sql('USE SCHEMA "foo"') 
assert_eq(original_df, c.sql("SELECT * FROM root.df")) c.sql("CREATE TABLE bar AS TABLE root.df") assert_eq(original_df, c.sql("SELECT * FROM bar")) with pytest.raises(KeyError): c.sql("CREATE TABLE other.bar AS TABLE df") c.sql('USE SCHEMA "root"') assert_eq(original_df, c.sql("SELECT * FROM foo.bar")) with pytest.raises(ParsingException): c.sql("SELECT * FROM bar") c.sql("DROP SCHEMA foo") with pytest.raises(ParsingException): c.sql("SELECT * FROM foo.bar") @pytest.mark.xfail(reason="WIP DataFusion") def test_function(c): c.sql("CREATE SCHEMA other") c.sql("USE SCHEMA root") def f(x): return x**2 c.register_function(f, "f", [("x", np.float64)], np.float64, schema_name="other") with pytest.raises(ParsingException): c.sql("SELECT F(a) AS a FROM df") c.sql("SELECT other.F(a) AS a FROM df") c.sql("USE SCHEMA other") c.sql("SELECT F(a) AS a FROM root.df") c.sql("USE SCHEMA root") fagg = dd.Aggregation("f", lambda x: x.sum(), lambda x: x.sum()) c.register_aggregation( fagg, "fagg", [("x", np.float64)], np.float64, schema_name="other" ) with pytest.raises(ParsingException): c.sql("SELECT FAGG(b) AS test FROM df") c.sql("SELECT other.FAGG(b) AS test FROM df") c.sql("USE SCHEMA other") c.sql("SELECT FAGG(b) AS test FROM root.df") def test_create_schema(c): c.sql("CREATE SCHEMA new_schema") assert "new_schema" in c.schema with pytest.raises(RuntimeError): c.sql("CREATE SCHEMA new_schema") c.sql("CREATE OR REPLACE SCHEMA new_schema") c.sql("CREATE SCHEMA IF NOT EXISTS new_schema") def test_drop_schema(c): with pytest.raises(RuntimeError): c.sql("DROP SCHEMA new_schema") c.sql("DROP SCHEMA IF EXISTS new_schema") c.sql("CREATE SCHEMA new_schema") c.sql("DROP SCHEMA IF EXISTS new_schema") with pytest.raises(RuntimeError): c.sql("USE SCHEMA new_schema") with pytest.raises(RuntimeError): c.sql("DROP SCHEMA root") c.sql("CREATE SCHEMA example") c.sql("USE SCHEMA example") c.sql("DROP SCHEMA example") assert c.schema_name == c.DEFAULT_SCHEMA_NAME assert "example" not in c.schema 
================================================ FILE: tests/integration/test_select.py ================================================ import numpy as np import pandas as pd import pytest from dask.dataframe.optimize import optimize_dataframe_getitem from dask.utils_test import hlg_layer from dask_sql.utils import ParsingException from tests.utils import assert_eq, skipif_dask_expr_enabled def test_select(c, df): result_df = c.sql("SELECT * FROM df") assert_eq(result_df, df) def test_select_alias(c, df): result_df = c.sql("SELECT a as b, b as a FROM df") expected_df = pd.DataFrame(index=df.index) expected_df["b"] = df.a expected_df["a"] = df.b assert_eq(result_df[["a", "b"]], expected_df[["a", "b"]]) def test_select_column(c, df): result_df = c.sql("SELECT a FROM df") assert_eq(result_df, df[["a"]]) def test_select_different_types(c): expected_df = pd.DataFrame( { "date": pd.to_datetime( ["2022-01-21 17:34", "2022-01-21", "17:34", pd.NaT], format="mixed", ), "string": ["this is a test", "another test", "äölüć", ""], "integer": [1, 2, -4, 5], "float": [-1.1, np.NaN, pd.NA, np.sqrt(2)], } ) c.create_table("df", expected_df) result_df = c.sql( """ SELECT * FROM df """ ) assert_eq(result_df, expected_df) def test_select_expr(c, df): result_df = c.sql("SELECT a + 1 AS a, b AS bla, a - 1 FROM df") result_df = result_df expected_df = pd.DataFrame( { "a": df["a"] + 1, "bla": df["b"], "df.a - Int64(1)": df["a"] - 1, } ) assert_eq(result_df, expected_df) def test_select_of_select(c, df): result_df = c.sql( """ SELECT 2*c AS e, d - 1 AS f FROM ( SELECT a - 1 AS c, 2*b AS d FROM df ) AS "inner" """ ) expected_df = pd.DataFrame({"e": 2 * (df["a"] - 1), "f": 2 * df["b"] - 1}) assert_eq(result_df, expected_df) @pytest.mark.xfail( reason="Column casing doesn't work as expected with datafusion>21, " "https://github.com/apache/arrow-datafusion/issues/5626" ) def test_select_of_select_with_casing(c, df): result_df = c.sql( """ SELECT "AAA", "aaa", "aAa" FROM ( SELECT a - 1 AS 
@pytest.mark.parametrize(
    "input_table",
    [
        "long_table",
        pytest.param("gpu_long_table", marks=pytest.mark.gpu),
    ],
)
@pytest.mark.parametrize(
    "limit,offset",
    [(101, 0), (200, 0), (100, 0), (100, 99), (100, 100), (101, 101), (0, 101)],
)
def test_limit(c, input_table, limit, offset, request):
    """Check ``LIMIT``/``OFFSET`` against the equivalent positional slice.

    A ``limit`` of 0 means "no LIMIT clause": only ``OFFSET`` is emitted and
    the expected slice is open-ended.
    """
    long_table = request.getfixturevalue(input_table)

    # Query the parametrized table rather than a hard-coded "long_table";
    # otherwise the GPU parametrization only swaps the expected frame and the
    # SQL never touches the GPU table.
    # NOTE(review): assumes the context fixture registers each table under its
    # fixture name, as test_date_casting relies on for gpu_datetime_table.
    if not limit:
        query = f"SELECT * FROM {input_table} OFFSET {offset}"
    else:
        query = f"SELECT * FROM {input_table} LIMIT {limit} OFFSET {offset}"

    stop = offset + limit if limit else None
    assert_eq(c.sql(query), long_table.iloc[offset:stop])
def test_operation():
    """``call.Operation`` forwards its argument to the wrapped callable."""
    # The local is called ``wrapped`` so that it does not reuse the name of
    # the ``operator`` module imported at the top of this file.
    wrapped = MagicMock()
    wrapped.return_value = "test"

    op = call.Operation(wrapped)
    outcome = op("input")

    assert outcome == "test"
    wrapped.assert_called_once_with("input")
def test_nan():
    """``IsNullOperation`` flags ``None``, NaN and ``pd.NA`` for scalars and series.

    ``np.nan`` replaces the ``np.NaN`` alias, which was removed in NumPy 2.0.
    """
    op = call.IsNullOperation()

    # Scalar inputs.
    assert op(None)
    assert op(np.nan)
    assert op(pd.NA)

    # Series inputs: object dtype with None, then a series mixing NaN and pd.NA.
    assert_eq(op(pd.Series(["a", None, "c"])), pd.Series([False, True, False]))
    assert_eq(
        op(pd.Series([3, 2, np.nan, pd.NA])), pd.Series([False, False, True, True])
    )
def test_string_operations():
    """Check the scalar string operations registered in ``ops_mapping``."""
    a = "a normal string"

    # (operation name, positional arguments, expected result)
    cases = (
        ("characterlength", (a,), 15),
        ("upper", (a,), "A NORMAL STRING"),
        ("lower", (a,), "a normal string"),
        ("position", ("a", a, 4), 7),
        ("position", ("ZL", a), 0),
        ("trim", (a, "a"), " normal string"),
        ("btrim", (a, "a"), " normal string"),
        ("ltrim", (a, "a"), " normal string"),
        ("rtrim", (a, "a"), "a normal string"),
        ("overlay", (a, "XXX", 2), "aXXXrmal string"),
        ("overlay", (a, "XXX", 2, 4), "aXXXmal string"),
        ("overlay", (a, "XXX", 2, 1), "aXXXnormal string"),
        ("substring", (a, -1), "a normal string"),
        ("substring", (a, 10), "string"),
        ("substring", (a, 2), " normal string"),
        ("substring", (a, 2, 2), " n"),
        ("initcap", (a,), "A Normal String"),
        ("replace", (a, "nor", ""), "a mal string"),
        ("replace", (a, "normal", "new"), "a new string"),
        ("replace", ("hello", "", "w"), "whwewlwlwow"),
    )

    for operation_name, arguments, expected in cases:
        assert ops_mapping[operation_name](*arguments) == expected
assert op("HOUR", date) == 15 assert op("MICROSECOND", date) == 47 assert op("MILLENNIUM", date) == 2 assert op("MILLISECOND", date) == 47000 assert op("MINUTE", date) == 53 assert op("MONTH", date) == 10 assert op("QUARTER", date) == 4 assert op("SECOND", date) == 42 assert op("WEEK", date) == 39 assert op("YEAR", date) == 2021 assert op("DATE", date) == datetime.date(2021, 10, 3) ceil_op = call.CeilFloorOperation("ceil") floor_op = call.CeilFloorOperation("floor") assert ceil_op(date, "DAY") == datetime.datetime(2021, 10, 4) assert ceil_op(date, "HOUR") == datetime.datetime(2021, 10, 3, 16) assert ceil_op(date, "MINUTE") == datetime.datetime(2021, 10, 3, 15, 54) assert ceil_op(date, "SECOND") == datetime.datetime(2021, 10, 3, 15, 53, 43) assert ceil_op(date, "MILLISECOND") == datetime.datetime( 2021, 10, 3, 15, 53, 42, 1000 ) assert floor_op(date, "DAY") == datetime.datetime(2021, 10, 3) assert floor_op(date, "HOUR") == datetime.datetime(2021, 10, 3, 15) assert floor_op(date, "MINUTE") == datetime.datetime(2021, 10, 3, 15, 53) assert floor_op(date, "SECOND") == datetime.datetime(2021, 10, 3, 15, 53, 42) assert floor_op(date, "MILLISECOND") == datetime.datetime(2021, 10, 3, 15, 53, 42) ================================================ FILE: tests/unit/test_config.py ================================================ import os import sys from unittest import mock import dask.dataframe as dd import pandas as pd import pytest import yaml from dask import config as dask_config # Required to instantiate default sql config import dask_sql # noqa: F401 from dask_sql import Context from tests.utils import skipif_dask_expr_enabled def test_custom_yaml(tmpdir): custom_config = {} custom_config["sql"] = dask_config.get("sql") custom_config["sql"]["aggregate"]["split_out"] = 16 custom_config["sql"]["foo"] = {"bar": [1, 2, 3], "baz": None} with open(os.path.join(tmpdir, "custom-sql.yaml"), mode="w") as f: yaml.dump(custom_config, f) dask_config.refresh( paths=[tmpdir] ) # Refresh 
def test_schema_is_complete():
    """Verify sql.yaml and sql-schema.yaml declare the same keys, recursively.

    At every nesting level the keys of the configuration mapping must equal,
    in order, the keys under ``properties`` of the matching schema node.
    """
    config_fn = os.path.join(os.path.dirname(__file__), "../../dask_sql", "sql.yaml")
    schema_fn = os.path.join(
        os.path.dirname(__file__), "../../dask_sql", "sql-schema.yaml"
    )

    with open(config_fn) as f:
        config = yaml.safe_load(f)

    with open(schema_fn) as f:
        schema = yaml.safe_load(f)

    def test_matches(c, s):
        # An empty mapping was never compared before (the check lived inside
        # the loop over its items), so keep skipping it.
        if not c:
            return

        # The key comparison does not depend on the item being visited, so it
        # runs once per level instead of once per key.
        if list(c) != list(s["properties"]):
            raise ValueError(
                "\nThe sql.yaml and sql-schema.yaml files are not in sync.\n"
                "This usually happens when we add a new configuration value,\n"
                "but don't add the schema of that value to the sql-schema.yaml file\n"
                "Please modify these files to include the missing values: \n\n"
                " sql.yaml: {}\n"
                " sql-schema.yaml: {}\n\n"
                "Examples in these files should be a good start, \n"
                "even if you are not familiar with the jsonschema spec".format(
                    sorted(c), sorted(s["properties"])
                )
            )

        for k, v in c.items():
            if isinstance(v, dict):
                test_matches(v, s["properties"][k])

    test_matches(config, schema)
@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason="Writing and reading the Dask DataFrame causes a ProtocolError",
)
@skipif_dask_expr_enabled("dynamic partition pruning not yet supported with dask-expr")
def test_dynamic_partition_pruning(tmpdir):
    """The INLIST filter shows up in EXPLAIN only with DPP and verbose both on.

    Configuration is applied through the context-manager form of
    ``dask_config.set`` so each setting is undone when its block exits; the
    previous bare ``set`` calls left both options switched on for every test
    that ran afterwards.
    """
    c = Context()

    df1 = pd.DataFrame(
        {
            "x": [1, 2, 3],
            "z": [7, 8, 9],
        },
    )
    dd.from_pandas(df1, npartitions=3).to_parquet(os.path.join(tmpdir, "df1"))
    df1 = dd.read_parquet(os.path.join(tmpdir, "df1"))
    c.create_table("df1", df1)

    df2 = pd.DataFrame(
        {
            "x": [1, 2, 3] * 1000,
            "y": [4, 5, 6] * 1000,
        },
    )
    dd.from_pandas(df2, npartitions=3).to_parquet(os.path.join(tmpdir, "df2"))
    df2 = dd.read_parquet(os.path.join(tmpdir, "df2"))
    c.create_table("df2", df2)

    query = "SELECT * FROM df1, df2 WHERE df1.x = df2.x AND df1.z=7"
    inlist_expr = "df2.x IN ([Int64(1)])"

    def explain_with(dpp, verbose):
        # Both options are scoped to this single explain() call.
        with dask_config.set(
            {
                "sql.dynamic_partition_pruning": dpp,
                "sql.optimizer.verbose": verbose,
            }
        ):
            return c.explain(query)

    # When DPP is turned off, the explain output will not contain the INLIST
    # expression, even with verbose output enabled.
    assert inlist_expr not in explain_with(dpp=False, verbose=True)

    # When DPP is turned on but sql.optimizer.verbose is off, the explain
    # output will not contain the INLIST expression.
    assert inlist_expr not in explain_with(dpp=True, verbose=False)

    # When both DPP and sql.optimizer.verbose are turned on, the explain
    # output will contain the INLIST expression.
    assert inlist_expr in explain_with(dpp=True, verbose=True)
@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_add_remove_tables(gpu):
    """Tables can be created, dropped once, and created again from a list."""
    ctx = Context()
    empty_ddf = dd.from_pandas(pd.DataFrame(), npartitions=1)

    def registered_tables():
        # Looked up on every call so the current table registry is inspected.
        return ctx.schema[ctx.schema_name].tables

    ctx.create_table("table", empty_ddf, gpu=gpu)
    assert "table" in registered_tables()

    ctx.drop_table("table")
    assert "table" not in registered_tables()

    # Dropping a table that is no longer registered raises.
    with pytest.raises(KeyError):
        ctx.drop_table("table")

    # A one-element list of frames is passed as input here.
    ctx.create_table("table", [empty_ddf], gpu=gpu)
    assert "table" in registered_tables()
@pytest.mark.parametrize(
    "gpu",
    [
        False,
        pytest.param(
            True,
            marks=pytest.mark.gpu,
        ),
    ],
)
def test_input_types(temporary_data_file, gpu):
    """``create_table`` accepts pandas, Dask, CSV and parquet inputs."""
    c = Context()
    df = pd.DataFrame({"a": [1, 2, 3]})

    def check_registered_table():
        result = c.sql("SELECT * FROM df")
        expected_type = dask_cudf.DataFrame if gpu else dd.DataFrame
        assert isinstance(result, expected_type)
        assert_eq(result, df)

    # In-memory inputs: the pandas frame, then a Dask frame built from it.
    for in_memory_input in (df, dd.from_pandas(df, npartitions=1)):
        c.create_table("df", in_memory_input, gpu=gpu)
        check_registered_table()

    # CSV on disk: first without a format argument, then with format="csv".
    for extra_options in ({}, {"format": "csv"}):
        df.to_csv(temporary_data_file, index=False)
        c.create_table("df", temporary_data_file, gpu=gpu, **extra_options)
        check_registered_table()

    df.to_parquet(temporary_data_file, index=False)
    c.create_table("df", temporary_data_file, format="parquet", gpu=gpu)
    check_registered_table()

    with pytest.raises(AttributeError):
        c.create_table("df", temporary_data_file, format="unknown", gpu=gpu)

    unsupported_input = object()
    with pytest.raises(ValueError):
        c.create_table("df", unsupported_input, gpu=gpu)
def test_function_adding():
    """Registering a scalar function adds an upper- and lower-case signature.

    Covers a first registration, an additional one with ``replace=False``
    and an overwriting one with ``replace=True``.
    """
    c = Context()

    def current_schema():
        # Fetched on every use so the state after each registration is read.
        return c.schema[c.schema_name]

    def assert_signature(position, name, parameter_type, return_type):
        entry = current_schema().function_lists[position]
        assert entry.name == name
        assert entry.parameters[0][0] == "x"
        assert str(entry.parameters[0][1]) == parameter_type
        assert str(entry.return_type) == return_type
        assert not entry.aggregation

    assert not current_schema().function_lists
    assert not current_schema().functions

    f = lambda x: x
    c.register_function(f, "f", [("x", int)], float)

    assert "f" in current_schema().functions
    assert current_schema().functions["f"].func == f
    assert len(current_schema().function_lists) == 2
    assert_signature(0, "F", DEFAULT_INT_TYPE, "DOUBLE")
    assert_signature(1, "f", DEFAULT_INT_TYPE, "DOUBLE")

    # Without replacement
    c.register_function(f, "f", [("x", float)], int, replace=False)

    assert "f" in current_schema().functions
    assert current_schema().functions["f"].func == f
    assert len(current_schema().function_lists) == 4
    assert_signature(2, "F", "DOUBLE", DEFAULT_INT_TYPE)
    assert_signature(3, "f", "DOUBLE", DEFAULT_INT_TYPE)

    # With replacement
    f = lambda x: x + 1
    c.register_function(f, "f", [("x", str)], str, replace=True)

    assert "f" in current_schema().functions
    assert current_schema().functions["f"].func == f
    assert len(current_schema().function_lists) == 2
    assert_signature(0, "F", "VARCHAR", "VARCHAR")
    assert_signature(1, "f", "VARCHAR", "VARCHAR")
c.schema[c.schema_name].function_lists[0].aggregation assert c.schema[c.schema_name].function_lists[1].name == "f" assert c.schema[c.schema_name].function_lists[1].parameters[0][0] == "x" assert ( str(c.schema[c.schema_name].function_lists[1].parameters[0][1]) == DEFAULT_INT_TYPE ) assert str(c.schema[c.schema_name].function_lists[1].return_type) == "DOUBLE" assert c.schema[c.schema_name].function_lists[1].aggregation # Without replacement c.register_aggregation(f, "f", [("x", float)], int, replace=False) assert "f" in c.schema[c.schema_name].functions assert c.schema[c.schema_name].functions["f"] == f assert len(c.schema[c.schema_name].function_lists) == 4 assert c.schema[c.schema_name].function_lists[2].name == "F" assert c.schema[c.schema_name].function_lists[2].parameters[0][0] == "x" assert str(c.schema[c.schema_name].function_lists[2].parameters[0][1]) == "DOUBLE" assert ( str(c.schema[c.schema_name].function_lists[2].return_type) == DEFAULT_INT_TYPE ) assert c.schema[c.schema_name].function_lists[2].aggregation assert c.schema[c.schema_name].function_lists[3].name == "f" assert c.schema[c.schema_name].function_lists[3].parameters[0][0] == "x" assert str(c.schema[c.schema_name].function_lists[3].parameters[0][1]) == "DOUBLE" assert ( str(c.schema[c.schema_name].function_lists[3].return_type) == DEFAULT_INT_TYPE ) assert c.schema[c.schema_name].function_lists[3].aggregation # With replacement f = lambda x: x + 1 c.register_aggregation(f, "f", [("x", str)], str, replace=True) assert "f" in c.schema[c.schema_name].functions assert c.schema[c.schema_name].functions["f"] == f assert len(c.schema[c.schema_name].function_lists) == 2 assert c.schema[c.schema_name].function_lists[0].name == "F" assert c.schema[c.schema_name].function_lists[0].parameters[0][0] == "x" assert str(c.schema[c.schema_name].function_lists[0].parameters[0][1]) == "VARCHAR" assert str(c.schema[c.schema_name].function_lists[0].return_type) == "VARCHAR" assert 
def test_alter_table(c, df_simple):
    """``ALTER TABLE ... RENAME TO`` moves a table to its new name."""
    c.create_table("maths", df_simple)

    c.sql("ALTER TABLE maths RENAME TO physics")
    tables_after_rename = c.schema[c.schema_name].tables
    assert "physics" in tables_after_rename
    assert "maths" not in tables_after_rename

    # Renaming a table that was never created raises ...
    with pytest.raises(KeyError):
        c.sql("ALTER TABLE four_legs RENAME TO two_legs")

    # ... while the IF EXISTS form is expected to run without raising.
    c.sql("ALTER TABLE IF EXISTS alien RENAME TO humans")

    # Remove the renamed table again.
    del c.schema[c.schema_name].tables["physics"]
def test_cc_limit_to():
    """``limit_to`` narrows the columns of the result and leaves the source alone."""
    all_columns = ["a", "b", "c"]
    identity_mapping = [("a", "a"), ("b", "b"), ("c", "c")]

    source = ColumnContainer(["a", "b", "c"])
    limited = source.limit_to(["c", "a"])

    # The result lists only the requested columns, in the requested order,
    # while its mapping still has an entry for every original column.
    assert limited.columns == ["c", "a"]
    assert limited.mapping() == identity_mapping

    # The container that limit_to was called on is unchanged.
    assert source.columns == all_columns
    assert source.mapping() == identity_mapping
def test_sql_to_python():
    """``sql_to_python_value`` converts SQL literals to the expected Python values."""
    assert sql_to_python_value(SqlTypeName.VARCHAR, "test 123") == "test 123"
    # Exact-type checks compare type objects by identity (``is``), not ``==``.
    assert type(sql_to_python_value(SqlTypeName.BIGINT, 653)) is np.int64
    assert sql_to_python_value(SqlTypeName.BIGINT, 653) == 653
    assert sql_to_python_value(SqlTypeName.INTERVAL, 4) == timedelta(microseconds=4000)
from dask.array.utils import assert_eq as assert_eq_ar from dask.dataframe.utils import assert_eq as assert_eq_df from sklearn.base import clone from sklearn.decomposition import PCA from sklearn.ensemble import GradientBoostingClassifier from sklearn.linear_model import LogisticRegression, SGDClassifier from dask_sql.physical.rel.custom.wrappers import Incremental, ParallelPostFit @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) def test_ml_class_mappings(gpu): from dask_sql.physical.utils.ml_classes import get_cpu_classes, get_gpu_classes from dask_sql.utils import import_class try: import lightgbm import xgboost except KeyError: lightgbm = None xgboost = None classes_dict = get_gpu_classes() if gpu else get_cpu_classes() for key in classes_dict: if not ("XGB" in key and xgboost is None) and not ( "LGBM" in key and lightgbm is None ): import_class(classes_dict[key]) def _check_axis_partitioning(chunks, n_features): c = chunks[1][0] if c != n_features: msg = ( "Can only generate arrays partitioned along the " "first axis. 
Specifying a larger chunksize for " "the second axis.\n\n\tchunk size: {}\n" "\tn_features: {}".format(c, n_features) ) raise ValueError(msg) def check_random_state(random_state): if random_state is None: return da.random.RandomState() # elif isinstance(random_state, Integral): # return da.random.RandomState(random_state) elif isinstance(random_state, np.random.RandomState): return da.random.RandomState(random_state.randint()) elif isinstance(random_state, da.random.RandomState): return random_state else: raise TypeError(f"Unexpected type '{type(random_state)}'") def make_classification( n_samples=100, n_features=20, n_informative=2, n_classes=2, scale=1.0, random_state=None, chunks=None, ): chunks = da.core.normalize_chunks(chunks, (n_samples, n_features)) _check_axis_partitioning(chunks, n_features) if n_classes != 2: raise NotImplementedError("n_classes != 2 is not yet supported.") rng = check_random_state(random_state) X = rng.normal(0, 1, size=(n_samples, n_features), chunks=chunks) informative_idx = rng.choice(n_features, n_informative, chunks=n_informative) beta = (rng.random(n_features, chunks=n_features) - 1) * scale informative_idx, beta = dask.compute( informative_idx, beta, scheduler="single-threaded" ) z0 = X[:, informative_idx].dot(beta[informative_idx]) y = rng.random(z0.shape, chunks=chunks[0]) < 1 / (1 + da.exp(-z0)) y = y.astype(int) return X, y def _assert_eq(l, r, name=None, **kwargs): array_types = (np.ndarray, da.Array) frame_types = (pd.core.generic.NDFrame, dd.DataFrame) if isinstance(l, array_types): assert_eq_ar(l, r, **kwargs) elif isinstance(l, frame_types): assert_eq_df(l, r, **kwargs) elif isinstance(l, Sequence) and any( isinstance(x, array_types + frame_types) for x in l ): for a, b in zip(l, r): _assert_eq(a, b, **kwargs) elif np.isscalar(r) and np.isnan(r): assert np.isnan(l), (name, l, r) else: assert l == r, (name, l, r) def assert_estimator_equal(left, right, exclude=None, **kwargs): """Check that two Estimators are equal 
Parameters ---------- left, right : Estimators exclude : str or sequence of str attributes to skip in the check kwargs : dict Passed through to the dask `assert_eq` method. """ left_attrs = [x for x in dir(left) if x.endswith("_") and not x.startswith("_")] right_attrs = [x for x in dir(right) if x.endswith("_") and not x.startswith("_")] if exclude is None: exclude = set() elif isinstance(exclude, str): exclude = {exclude} else: exclude = set(exclude) left_attrs2 = set(left_attrs) - exclude right_attrs2 = set(right_attrs) - exclude assert left_attrs2 == right_attrs2, left_attrs2 ^ right_attrs2 for attr in left_attrs2: l = getattr(left, attr) r = getattr(right, attr) _assert_eq(l, r, name=attr, **kwargs) def test_parallelpostfit_basic(): clf = ParallelPostFit(GradientBoostingClassifier()) X, y = make_classification(n_samples=1000, chunks=100) X_, y_ = dask.compute(X, y) clf.fit(X_, y_) assert isinstance(clf.predict(X), da.Array) assert isinstance(clf.predict_proba(X), da.Array) result = clf.score(X, y) expected = clf.estimator.score(X_, y_) assert result == expected @pytest.mark.parametrize("kind", ["numpy", "dask.dataframe", "dask.array"]) def test_predict(kind): X, y = make_classification(chunks=100) if kind == "numpy": X, y = dask.compute(X, y) elif kind == "dask.dataframe": X = dd.from_dask_array(X) y = dd.from_dask_array(y) base = LogisticRegression(random_state=0, n_jobs=1, solver="lbfgs") wrap = ParallelPostFit( LogisticRegression(random_state=0, n_jobs=1, solver="lbfgs"), ) base.fit(*dask.compute(X, y)) wrap.fit(*dask.compute(X, y)) assert_estimator_equal(wrap.estimator, base) result = wrap.predict(X) expected = base.predict(X) assert_eq_ar(result, expected) result = wrap.predict_proba(X) expected = base.predict_proba(X) assert_eq_ar(result, expected) result = wrap.predict_log_proba(X) expected = base.predict_log_proba(X) assert_eq_ar(result, expected) @pytest.mark.parametrize("kind", ["numpy", "dask.dataframe", "dask.array"]) def test_transform(kind): X, y 
= make_classification(chunks=100) if kind == "numpy": X, y = dask.compute(X, y) elif kind == "dask.dataframe": X = dd.from_dask_array(X) y = dd.from_dask_array(y) base = PCA(random_state=0) wrap = ParallelPostFit(PCA(random_state=0)) base.fit(*dask.compute(X, y)) wrap.fit(*dask.compute(X, y)) assert_estimator_equal(wrap.estimator, base) result = base.transform(*dask.compute(X)) expected = wrap.transform(X) assert_eq_ar(result, expected) @pytest.mark.parametrize("dataframes", [False, True]) def test_incremental_basic(dataframes): # Create observations that we know linear models can recover n, d = 100, 3 rng = da.random.RandomState(42) X = rng.normal(size=(n, d), chunks=30) coef_star = rng.uniform(size=d, chunks=d) y = da.sign(X.dot(coef_star)) y = (y + 1) / 2 if dataframes: X = dd.from_array(X) y = dd.from_array(y) est1 = SGDClassifier(random_state=0, tol=1e-3, average=True) est2 = clone(est1) clf = Incremental(est1, random_state=0) result = clf.fit(X, y, classes=[0, 1]) assert result is clf # est2 is a sklearn optimizer; this is just a benchmark if dataframes: X = X.to_dask_array(lengths=True) y = y.to_dask_array(lengths=True) for slice_ in da.core.slices_from_chunks(X.chunks): est2.partial_fit(X[slice_].compute(), y[slice_[0]].compute(), classes=[0, 1]) assert isinstance(result.estimator_.coef_, np.ndarray) rel_error = np.linalg.norm(clf.coef_ - est2.coef_) rel_error /= np.linalg.norm(clf.coef_) assert rel_error < 0.9 assert set(dir(clf.estimator_)) == set(dir(est2)) # Predict result = clf.predict(X) expected = est2.predict(X) assert isinstance(result, da.Array) if dataframes: # Compute is needed because chunk sizes of this array are unknown result = result.compute() rel_error = np.linalg.norm(result - expected) rel_error /= np.linalg.norm(expected) assert rel_error < 0.3 # score result = clf.score(X, y) expected = est2.score(*dask.compute(X, y)) assert abs(result - expected) < 0.1 clf = Incremental(SGDClassifier(random_state=0, tol=1e-3, average=True)) 
clf.partial_fit(X, y, classes=[0, 1]) assert set(dir(clf.estimator_)) == set(dir(est2)) ================================================ FILE: tests/unit/test_queries.py ================================================ import os import pytest XFAIL_QUERIES = ( 5, 8, 10, 14, 16, 18, 22, 23, 24, 27, 28, 35, 36, 39, 41, 44, 47, 49, 51, 57, 62, 64, # FIXME: failing after cudf#14167 and #14079 67, 69, 70, 72, 77, 80, 86, 88, 89, 92, 94, 99, ) QUERIES = [ pytest.param(f"q{i}.sql", marks=pytest.mark.xfail if i in XFAIL_QUERIES else ()) for i in range(1, 100) ] @pytest.fixture(scope="module") def c(data_dir): # Lazy import, otherwise the pytest framework has problems from dask_sql.context import Context c = Context() if not data_dir: data_dir = f"{os.path.dirname(__file__)}/data/" for table_name in os.listdir(data_dir): c.create_table( table_name, data_dir + "/" + table_name, format="parquet", gpu=False, ) yield c @pytest.fixture(scope="module") def gpu_c(data_dir): pytest.importorskip("dask_cudf") # Lazy import, otherwise the pytest framework has problems from dask_sql.context import Context c = Context() if not data_dir: data_dir = f"{os.path.dirname(__file__)}/data/" for table_name in os.listdir(data_dir): c.create_table( table_name, data_dir + "/" + table_name, format="parquet", gpu=True, ) yield c @pytest.mark.queries @pytest.mark.parametrize("query", QUERIES) def test_query(c, client, query, queries_dir): if not queries_dir: queries_dir = f"{os.path.dirname(__file__)}/queries/" with open(queries_dir + "/" + query) as f: sql = f.read() res = c.sql(sql) res.compute(scheduler=client) @pytest.mark.gpu @pytest.mark.queries @pytest.mark.parametrize("query", QUERIES) def test_gpu_query(gpu_c, gpu_client, query, queries_dir): if not queries_dir: queries_dir = f"{os.path.dirname(__file__)}/queries/" with open(queries_dir + "/" + query) as f: sql = f.read() res = gpu_c.sql(sql) res.compute(scheduler=gpu_client) ================================================ FILE: 
tests/unit/test_statistics.py ================================================ import dask.dataframe as dd import pandas as pd import pytest from dask_sql import Context from dask_sql.datacontainer import Statistics from dask_sql.physical.utils.statistics import parquet_statistics from tests.utils import skipif_dask_expr_enabled # TODO: add support for parquet statistics with dask-expr pytestmark = skipif_dask_expr_enabled( reason="Parquet statistics not yet supported with dask-expr" ) @pytest.mark.parametrize("parallel", [None, False, 2]) def test_parquet_statistics(parquet_ddf, parallel): # Check simple num-rows statistics stats = parquet_statistics(parquet_ddf, parallel=parallel) stats_df = pd.DataFrame(stats) num_rows = stats_df["num-rows"].sum() assert len(stats_df) == parquet_ddf.npartitions assert num_rows == len(parquet_ddf) # Check simple column statistics stats = parquet_statistics(parquet_ddf, columns=["b"], parallel=parallel) b_stats = [ { "min": stat["columns"][0]["min"], "max": stat["columns"][0]["max"], } for stat in stats ] b_stats_df = pd.DataFrame(b_stats) assert b_stats_df["min"].min() == parquet_ddf["b"].min().compute() assert b_stats_df["max"].max() == parquet_ddf["b"].max().compute() def test_parquet_statistics_bad_args(parquet_ddf): # Check "bad" input arguments to parquet_statistics # ddf argument must be a Dask-DataFrame object pdf = pd.DataFrame({"a": range(10)}) with pytest.raises(ValueError, match="Expected Dask DataFrame"): parquet_statistics(pdf) # Return should be None if parquet statistics # cannot be extracted from the provided collection ddf = dd.from_pandas(pdf, npartitions=2) assert parquet_statistics(ddf) is None # Clear error should be raised when columns is not # a list containing a subset of columns from ddf with pytest.raises(ValueError, match="Expected columns to be a list"): parquet_statistics(parquet_ddf, columns="bad") with pytest.raises(ValueError, match="must be a subset"): parquet_statistics(parquet_ddf, 
columns=["bad"]) def test_dc_statistics(parquet_ddf): c = Context() c.create_table("df", parquet_ddf) assert c.schema["root"].tables["df"].statistics == Statistics(row_count=15) assert c.schema["root"].statistics["df"] == Statistics(row_count=15) ================================================ FILE: tests/unit/test_utils.py ================================================ import pandas as pd import pytest from dask import dataframe as dd from dask.utils_test import hlg_layer from dask_sql.physical.utils.filter import attempt_predicate_pushdown from dask_sql.utils import Pluggable, is_frame from tests.utils import skipif_dask_expr_enabled def test_is_frame_for_frame(): df = dd.from_pandas(pd.DataFrame({"a": [1]}), npartitions=1) assert is_frame(df) def test_is_frame_for_none(): assert not is_frame(None) def test_is_frame_for_number(): assert not is_frame(3) assert not is_frame(3.5) class PluginTest1(Pluggable): pass class PluginTest2(Pluggable): pass def test_add_plugin(): PluginTest1.add_plugin("some_key", "value") assert PluginTest1.get_plugin("some_key") == "value" assert PluginTest1().get_plugin("some_key") == "value" with pytest.raises(KeyError): PluginTest2.get_plugin("some_key") def test_overwrite(): PluginTest1.add_plugin("some_key", "value") assert PluginTest1.get_plugin("some_key") == "value" assert PluginTest1().get_plugin("some_key") == "value" PluginTest1.add_plugin("some_key", "value_2") assert PluginTest1.get_plugin("some_key") == "value_2" assert PluginTest1().get_plugin("some_key") == "value_2" PluginTest1.add_plugin("some_key", "value_3", replace=False) assert PluginTest1.get_plugin("some_key") == "value_2" assert PluginTest1().get_plugin("some_key") == "value_2" @skipif_dask_expr_enabled() def test_predicate_pushdown_simple(parquet_ddf): filtered_df = parquet_ddf[parquet_ddf["a"] > 1] pushdown_df = attempt_predicate_pushdown(filtered_df) got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][ "filters" ] got_filters = 
frozenset(frozenset(v) for v in got_filters) expected_filters = [[("a", ">", 1)]] expected_filters = frozenset(frozenset(v) for v in expected_filters) assert got_filters == expected_filters @skipif_dask_expr_enabled() def test_predicate_pushdown_logical(parquet_ddf): filtered_df = parquet_ddf[ (parquet_ddf["a"] > 1) & (parquet_ddf["b"] < 2) | (parquet_ddf["a"] == -1) ] pushdown_df = attempt_predicate_pushdown(filtered_df) got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][ "filters" ] got_filters = frozenset(frozenset(v) for v in got_filters) expected_filters = [[("a", ">", 1), ("b", "<", 2)], [("a", "==", -1)]] expected_filters = frozenset(frozenset(v) for v in expected_filters) assert got_filters == expected_filters @skipif_dask_expr_enabled() def test_predicate_pushdown_in(parquet_ddf): filtered_df = parquet_ddf[ (parquet_ddf["a"] > 1) & (parquet_ddf["b"] < 2) | (parquet_ddf["a"] == -1) & parquet_ddf["c"].isin(("A", "B", "C")) | ~parquet_ddf["b"].isin((5, 6, 7)) ] pushdown_df = attempt_predicate_pushdown(filtered_df) got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][ "filters" ] got_filters = frozenset(frozenset(v) for v in got_filters) expected_filters = [ [("b", "<", 2), ("a", ">", 1)], [("a", "==", -1), ("c", "in", ("A", "B", "C"))], [("b", "not in", (5, 6, 7))], ] expected_filters = frozenset(frozenset(v) for v in expected_filters) assert got_filters == expected_filters @skipif_dask_expr_enabled() def test_predicate_pushdown_isna(parquet_ddf): filtered_df = parquet_ddf[ (parquet_ddf["a"] > 1) & (parquet_ddf["b"] < 2) | (parquet_ddf["a"] == -1) & ~parquet_ddf["c"].isna() | parquet_ddf["b"].isna() ] pushdown_df = attempt_predicate_pushdown(filtered_df) got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][ "filters" ] got_filters = frozenset(frozenset(v) for v in got_filters) expected_filters = [ [("b", "<", 2), ("a", ">", 1)], [("a", "==", -1), ("c", "is not", None)], 
[("b", "is", None)], ] expected_filters = frozenset(frozenset(v) for v in expected_filters) assert got_filters == expected_filters @skipif_dask_expr_enabled() def test_predicate_pushdown_add_filters(parquet_ddf): filtered_df = parquet_ddf[(parquet_ddf["a"] > 1) | (parquet_ddf["a"] == -1)] pushdown_df = attempt_predicate_pushdown( filtered_df, add_filters=("b", "<", 2), ) got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][ "filters" ] got_filters = frozenset(frozenset(v) for v in got_filters) expected_filters = [ [("a", ">", 1), ("b", "<", 2)], [("a", "==", -1), ("b", "<", 2)], ] expected_filters = frozenset(frozenset(v) for v in expected_filters) assert got_filters == expected_filters @skipif_dask_expr_enabled() def test_predicate_pushdown_add_filters_no_extract(parquet_ddf): filtered_df = parquet_ddf[(parquet_ddf["a"] > 1) | (parquet_ddf["a"] == -1)] pushdown_df = attempt_predicate_pushdown( filtered_df, extract_filters=False, add_filters=("b", "<", 2), ) got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][ "filters" ] got_filters = frozenset(frozenset(v) for v in got_filters) expected_filters = [[("b", "<", 2)]] expected_filters = frozenset(frozenset(v) for v in expected_filters) assert got_filters == expected_filters @skipif_dask_expr_enabled() def test_predicate_pushdown_add_filters_no_preserve(parquet_ddf): filtered_df = parquet_ddf[(parquet_ddf["a"] > 1) | (parquet_ddf["a"] == -1)] pushdown_df0 = attempt_predicate_pushdown(filtered_df) pushdown_df = attempt_predicate_pushdown( pushdown_df0, preserve_filters=False, extract_filters=False, add_filters=("b", "<", 2), ) got_filters = hlg_layer(pushdown_df.dask, "read-parquet").creation_info["kwargs"][ "filters" ] got_filters = frozenset(frozenset(v) for v in got_filters) expected_filters = [[("b", "<", 2)]] expected_filters = frozenset(frozenset(v) for v in expected_filters) assert got_filters == expected_filters 
================================================ FILE: tests/utils.py ================================================ import os import pytest from dask.dataframe import _dask_expr_enabled from dask.dataframe.utils import assert_eq as _assert_eq # use distributed client for testing if it's available scheduler = ( "distributed" if os.getenv("DASK_SQL_DISTRIBUTED_TESTS", "False").lower() in ("true", "1") else "sync" ) def assert_eq(*args, **kwargs): kwargs.setdefault("scheduler", scheduler) return _assert_eq(*args, **kwargs) def convert_nullable_columns(df): """ Convert certain nullable columns in `df` to non-nullable columns when trying to handle np.NaN and pd.NA would otherwise cause issues. """ dtypes_mapping = { "Int64": "float64", "Float64": "float64", "boolean": "float64", } for dtype in dtypes_mapping: selected_cols = df.select_dtypes(include=[dtype]).columns.tolist() if selected_cols: df[selected_cols] = df[selected_cols].astype(dtypes_mapping[dtype]) return df def skipif_dask_expr_enabled(reason=None): """ Skip the test if dask-expr is enabled """ # most common reason for skipping if reason is None: reason = "Predicate pushdown & column projection should be handled implicitly by dask-expr" return pytest.mark.skipif( _dask_expr_enabled(), reason=reason, )