[
  {
    "path": ".cargo/config.toml",
    "content": "[target.x86_64-apple-darwin]\nrustflags = [\n  \"-C\", \"link-arg=-undefined\",\n  \"-C\", \"link-arg=dynamic_lookup\",\n]\n\n[target.aarch64-apple-darwin]\nrustflags = [\n  \"-C\", \"link-arg=-undefined\",\n  \"-C\", \"link-arg=dynamic_lookup\",\n]\n"
  },
  {
    "path": ".coveragerc",
    "content": "[run]\nomit = tests/*\nbranch = True\n\n[report]\n# Regexes for lines to exclude from consideration\nexclude_lines =\n    # Have to re-enable the standard pragma\n    pragma: no cover\n\n    # Don't complain about missing debug-only code:\n    def __repr__\n\n    # Don't complain if tests don't hit defensive assertion code:\n    raise AssertionError\n    raise NotImplementedError\n\n    # Don't complain if non-runnable code isn't run:\n    if __name__ == .__main__.:\n"
  },
  {
    "path": ".dockerignore",
    "content": "node_modules\n.next\n"
  },
  {
    "path": ".github/CODEOWNERS",
    "content": "# global codeowners\n*              @ayushdg @charlesbluca @galipremsagar\n\n# rust codeowners\n.cargo/        @ayushdg @charlesbluca @galipremsagar @jdye64\nsrc/           @ayushdg @charlesbluca @galipremsagar @jdye64\nCargo.toml     @ayushdg @charlesbluca @galipremsagar @jdye64\nCargo.lock     @ayushdg @charlesbluca @galipremsagar @jdye64\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a bug report to help us improve dask-sql\ntitle: \"[BUG]\"\nlabels: \"bug, needs triage\"\nassignees: ''\n\n---\n\n<!-- Please include a self-contained copy-pastable example that generates the issue if possible.\n\nPlease be concise with code posted. See guidelines below on how to provide a good bug report:\n\n- Craft Minimal Bug Reports http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports\n- Minimal Complete Verifiable Examples https://stackoverflow.com/help/mcve\n\nBug reports that follow these guidelines are easier to diagnose, and so are often handled much more quickly.\n-->\n\n**What happened**:\n\n**What you expected to happen**:\n\n**Minimal Complete Verifiable Example**:\n\n```python\n# Put your MCVE code here\n```\n\n**Anything else we need to know?**:\n\n**Environment**:\n\n- dask-sql version:\n- Python version:\n- Operating System:\n- Install method (conda, pip, source):\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/documentation-request.md",
    "content": "---\nname: Documentation request\nabout: Report incorrect or needed documentation\ntitle: \"[DOC]\"\nlabels: \"documentation\"\nassignees: ''\n\n---\n\n## Report incorrect documentation\n\n**Location of incorrect documentation**\nProvide links and line numbers if applicable.\n\n**Describe the problems or issues found in the documentation**\nA clear and concise description of what you found to be incorrect.\n\n**Steps taken to verify documentation is incorrect**\nList any steps you have taken:\n\n**Suggested fix for documentation**\nDetail proposed changes to fix the documentation if you have any.\n\n---\n\n## Report needed documentation\n\n**Report needed documentation**\nA clear and concise description of what documentation you believe it is needed and why.\n\n**Describe the documentation you'd like**\nA clear and concise description of what you want to happen.\n\n**Steps taken to search for needed documentation**\nList any steps you have taken:\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest an idea for dask-sql\ntitle: \"[ENH]\"\nlabels: \"enhancement, needs triage\"\nassignees: ''\n\n---\n\n**Is your feature request related to a problem? Please describe.**\nA clear and concise description of what the problem is. Ex. I wish I could use dask-sql to do [...]\n\n**Describe the solution you'd like**\nA clear and concise description of what you want to happen.\n\n**Describe alternatives you've considered**\nA clear and concise description of any alternative solutions or features you've considered.\n\n**Additional context**\nAdd any other context, code examples, or references to existing implementations about the feature request here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/submit-question.md",
    "content": "---\nname: Submit question\nabout: Ask a general question about dask-sql\ntitle: \"[QST]\"\nlabels: \"question\"\nassignees: ''\n\n---\n\n**What is your question?**\n"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "version: 2\nupdates:\n  - package-ecosystem: \"cargo\"\n    directory: \"/\"\n    schedule:\n      interval: \"daily\"\n    ignore:\n      # arrow and datafusion are bumped manually\n      - dependency-name: \"arrow\"\n        update-types: [\"version-update:semver-major\"]\n      - dependency-name: \"datafusion\"\n        update-types: [\"version-update:semver-major\"]\n      - dependency-name: \"datafusion-*\"\n        update-types: [\"version-update:semver-major\"]\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    schedule:\n      # Check for updates to GitHub Actions once a week\n      interval: \"weekly\"\n    ignore:\n      # prefer updating cibuildwheel manually as needed\n      - dependency-name: \"pypa/cibuildwheel\"\n"
  },
  {
    "path": ".github/workflows/conda.yml",
    "content": "name: Build conda nightly\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    paths:\n      - Cargo.toml\n      - Cargo.lock\n      - pyproject.toml\n      - continuous_integration/recipe/**\n      - .github/workflows/conda.yml\n  schedule:\n    - cron: '0 0 * * 0'\n\n# When this workflow is queued, automatically cancel any previous running\n# or pending jobs from the same branch\nconcurrency:\n  group: conda-${{ github.head_ref }}\n  cancel-in-progress: true\n\n# Required shell entrypoint to have properly activated conda environments\ndefaults:\n  run:\n    shell: bash -l {0}\n\njobs:\n  conda:\n    name: \"Build conda nightlies (python: ${{ matrix.python }}, arch: ${{ matrix.arch }})\"\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        python: [\"3.9\", \"3.10\", \"3.11\", \"3.12\"]\n        arch: [\"linux-64\", \"linux-aarch64\"]\n    steps:\n      - name: Manage disk space\n        if: matrix.arch == 'linux-aarch64'\n        run: |\n          sudo mkdir -p /opt/empty_dir || true\n          for d in \\\n                    /opt/ghc \\\n                    /opt/hostedtoolcache \\\n                    /usr/lib/jvm \\\n                    /usr/local/.ghcup \\\n                    /usr/local/lib/android \\\n                    /usr/local/share/powershell \\\n                    /usr/share/dotnet \\\n                    /usr/share/swift \\\n                    ; do\n            sudo rsync --stats -a --delete /opt/empty_dir/ $d || true\n          done\n          sudo apt-get purge -y -f firefox \\\n                                    google-chrome-stable \\\n                                    microsoft-edge-stable\n          sudo apt-get autoremove -y >& /dev/null\n          sudo apt-get autoclean -y >& /dev/null\n          sudo docker image prune --all --force\n          df -h\n      - name: Create swapfile\n        if: matrix.arch == 'linux-aarch64'\n        run: |\n          sudo fallocate -l 
10GiB /swapfile || true\n          sudo chmod 600 /swapfile || true\n          sudo mkswap /swapfile || true\n          sudo swapon /swapfile || true\n      - uses: actions/checkout@v4\n        with:\n          fetch-depth: 0\n      - name: Set up Python\n        uses: conda-incubator/setup-miniconda@v2.3.0\n        with:\n          miniforge-variant: Mambaforge\n          use-mamba: true\n          python-version: \"3.9\"\n          channel-priority: strict\n      - name: Install dependencies\n        run: |\n          mamba install -c conda-forge \"boa<0.17\" \"conda-build<24.1\" conda-verify\n\n          which python\n          pip list\n          mamba list\n      - name: Build conda packages\n        run: |\n          # suffix for nightly package versions\n          export VERSION_SUFFIX=a`date +%y%m%d`\n\n          conda mambabuild continuous_integration/recipe \\\n                           --python ${{ matrix.python }} \\\n                           --variants \"{target_platform: [${{ matrix.arch }}]}\" \\\n                           --error-overlinking \\\n                           --no-test \\\n                           --no-anaconda-upload \\\n                           --output-folder packages\n      - name: Test conda packages\n        if: matrix.arch == 'linux-64'  # can only test native platform packages\n        run: |\n          conda mambabuild --test packages/${{ matrix.arch }}/*.tar.bz2\n      - name: Upload conda packages as artifacts\n        uses: actions/upload-artifact@v3\n        with:\n          name: \"conda nightlies (python - ${{ matrix.python }}, arch - ${{ matrix.arch }})\"\n          # need to install all conda channel metadata to properly install locally\n          path: packages/\n      - name: Upload conda packages to Anaconda\n        if: |\n          github.event_name == 'push'\n          && github.repository == 'dask-contrib/dask-sql'\n        env:\n          ANACONDA_API_TOKEN: ${{ secrets.DASK_CONDA_TOKEN }}\n        run: 
|\n          # install anaconda for upload\n          mamba install -c conda-forge anaconda-client\n\n          anaconda upload --label dev packages/${{ matrix.arch }}/*.tar.bz2\n"
  },
  {
    "path": ".github/workflows/docker.yml",
    "content": "name: Build Docker image\n\non:\n  release:\n    types: [created]\n  push:\n    branches:\n      - main\n  pull_request:\n    paths:\n      - Cargo.toml\n      - Cargo.lock\n      - pyproject.toml\n      - continuous_integration/docker/**\n      - .github/workflows/docker.yml\n\n# When this workflow is queued, automatically cancel any previous running\n# or pending jobs from the same branch\nconcurrency:\n  group: docker-${{ github.ref }}\n  cancel-in-progress: true\n\njobs:\n  push_to_registry:\n    name: Push Docker image to Docker Hub\n    runs-on: ubuntu-latest\n    env:\n      DOCKER_PUSH: ${{ contains(fromJSON('[\"push\", \"release\"]'), github.event_name) && github.repository == 'dask-contrib/dask-sql' }}\n    strategy:\n      fail-fast: false\n      matrix:\n        platform: [\"linux/amd64\", \"linux/arm64\", \"linux/386\"]\n    steps:\n      - uses: actions/checkout@v4\n      - name: Login to DockerHub\n        if: ${{ fromJSON(env.DOCKER_PUSH) }}\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKER_USERNAME }}\n          password: ${{ secrets.DOCKER_PASSWORD }}\n      - name: Docker meta for main image\n        id: docker_meta_main\n        uses: crazy-max/ghaction-docker-meta@v5\n        with:\n          images: nbraun/dask-sql\n      - name: Build and push main image\n        uses: docker/build-push-action@v5\n        with:\n          context: .\n          file: ./continuous_integration/docker/main.dockerfile\n          build-args: DOCKER_META_VERSION=${{ steps.docker_meta_main.outputs.version }}\n          platforms: ${{ matrix.platform }}\n          tags: ${{ steps.docker_meta_main.outputs.tags }}\n          labels: ${{ steps.docker_meta_main.outputs.labels }}\n          push: ${{ fromJSON(env.DOCKER_PUSH) }}\n          load: ${{ !fromJSON(env.DOCKER_PUSH) }}\n      - name: Check images\n        run: |\n          df -h\n          docker image ls\n          docker image inspect ${{ 
steps.docker_meta_main.outputs.tags }}\n      - name: Docker meta for cloud image\n        id: docker_meta_cloud\n        uses: crazy-max/ghaction-docker-meta@v5\n        with:\n          images: nbraun/dask-sql-cloud\n      - name: Build and push cloud image\n        uses: docker/build-push-action@v5\n        with:\n          context: .\n          file: ./continuous_integration/docker/cloud.dockerfile\n          build-args: DOCKER_META_VERSION=${{ steps.docker_meta_main.outputs.version }}\n          platforms: ${{ matrix.platform }}\n          tags: ${{ steps.docker_meta_cloud.outputs.tags }}\n          labels: ${{ steps.docker_meta_cloud.outputs.labels }}\n          push: ${{ fromJSON(env.DOCKER_PUSH) }}\n          load: ${{ !fromJSON(env.DOCKER_PUSH) }}\n"
  },
  {
    "path": ".github/workflows/release.yml",
    "content": "name: Upload Python package\non:\n  release:\n    types: [created]\n  pull_request:\n    paths:\n      - .github/workflows/release.yml\n      - dask_sql/__init__.py\n\n# When this workflow is queued, automatically cancel any previous running\n# or pending jobs from the same branch\nconcurrency:\n  group: release-${{ github.head_ref }}\n  cancel-in-progress: true\n\nenv:\n  upload: ${{ github.event_name == 'release' && github.repository == 'dask-contrib/dask-sql' }}\n\njobs:\n  linux:\n    name: Build and publish wheels for linux ${{ matrix.target }}\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        target: [x86_64, aarch64]\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v4\n        with:\n          python-version: '3.10'\n      - name: Build wheels for x86_64\n        if: matrix.target == 'x86_64'\n        uses: PyO3/maturin-action@v1\n        with:\n          target: ${{ matrix.target }}\n          args: --release --out dist\n          sccache: 'true'\n          manylinux: '2_17'\n      - name: Build wheels for aarch64\n        if: matrix.target == 'aarch64'\n        uses: PyO3/maturin-action@v1\n        with:\n          target: ${{ matrix.target }}\n          args: --release --out dist --zig\n          sccache: 'true'\n          manylinux: '2_17'\n      - name: Check dist files\n        run: |\n          pip install twine\n\n          twine check dist/*\n          ls -lh dist/\n      - name: Upload binary wheels\n        uses: actions/upload-artifact@v3\n        with:\n          name: wheels for linux ${{ matrix.target }}\n          path: dist/*\n      - name: Publish package\n        if: env.upload == 'true'\n        env:\n          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}\n          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}\n        run: twine upload dist/*\n\n  windows:\n    name: Build and publish wheels for windows\n    runs-on: windows-latest\n    steps:\n    
  - uses: actions/checkout@v4\n      - uses: actions/setup-python@v4\n        with:\n          python-version: '3.10'\n          architecture: x64\n      - name: Build wheels\n        uses: PyO3/maturin-action@v1\n        with:\n          target: x64\n          args: --release --out dist\n          sccache: 'true'\n      - name: Check dist files\n        run: |\n          pip install twine\n\n          twine check dist/*\n          ls dist/\n      - name: Upload binary wheels\n        uses: actions/upload-artifact@v3\n        with:\n          name: wheels for windows\n          path: dist/*\n      - name: Publish package\n        if: env.upload == 'true'\n        env:\n          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}\n          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}\n        run: twine upload dist/*\n\n  macos:\n    name: Build and publish wheels for macos ${{ matrix.target }}\n    runs-on: macos-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        target: [x86_64, aarch64]\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v4\n        with:\n          python-version: '3.10'\n      - name: Build wheels\n        uses: PyO3/maturin-action@v1\n        with:\n          target: ${{ matrix.target }}\n          args: --release --out dist\n          sccache: 'true'\n      - name: Check dist files\n        run: |\n          pip install twine\n\n          twine check dist/*\n          ls -lh dist/\n      - name: Upload binary wheels\n        uses: actions/upload-artifact@v3\n        with:\n          name: wheels for macos ${{ matrix.target }}\n          path: dist/*\n      - name: Publish package\n        if: env.upload == 'true'\n        env:\n          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}\n          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}\n        run: twine upload dist/*\n\n  sdist:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - name: Build sdist\n        uses: 
PyO3/maturin-action@v1\n        with:\n          command: sdist\n          args: --out dist\n      - uses: actions/setup-python@v4\n        with:\n          python-version: '3.10'\n      - name: Check dist files\n        run: |\n          pip install twine\n\n          twine check dist/*\n          ls -lh dist/\n      - name: Publish source distribution\n        if: env.upload == 'true'\n        env:\n          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}\n          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}\n        run: twine upload dist/*\n"
  },
  {
    "path": ".github/workflows/rust.yml",
    "content": "name: Test Rust package\n\non:\n  # always trigger on PR\n  push:\n    branches:\n      - main\n  pull_request:\n  # manual trigger\n  # https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow\n  workflow_dispatch:\n\nenv:\n  # Disable full debug symbol generation to speed up CI build and keep memory down\n  # \"1\" means line tables only, which is useful for panic tracebacks.\n  RUSTFLAGS: \"-C debuginfo=1\"\n\njobs:\n  detect-ci-trigger:\n    name: Check for upstream trigger phrase\n    runs-on: ubuntu-latest\n    if: github.repository == 'dask-contrib/dask-sql'\n    outputs:\n      triggered: ${{ steps.detect-trigger.outputs.trigger-found }}\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          fetch-depth: 2\n      - uses: xarray-contrib/ci-trigger@v1.2\n        id: detect-trigger\n        with:\n          keyword: \"[test-df-upstream]\"\n\n  # Check crate compiles\n  linux-build-lib:\n    name: cargo check\n    needs: [detect-ci-trigger]\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions-rs/toolchain@v1\n        with:\n            toolchain: 1.72\n            default: true\n      - name: Cache Cargo\n        uses: actions/cache@v3\n        with:\n          path: /home/runner/.cargo\n          key: cargo-cache\n      - name: Optionally update upstream dependencies\n        if: needs.detect-ci-trigger.outputs.triggered == 'true'\n        run: |\n          bash continuous_integration/scripts/update-dependencies.sh\n      - name: Check workspace in debug mode\n        run: |\n          cargo check\n      - name: Check workspace in release mode\n        run: |\n          cargo check --release\n\n  # test the crate\n  linux-test:\n    name: cargo test\n    needs: [detect-ci-trigger]\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          submodules: true\n      - uses: actions-rs/toolchain@v1\n        with:\n    
        toolchain: 1.72\n            default: true\n      - name: Cache Cargo\n        uses: actions/cache@v3\n        with:\n          path: /home/runner/.cargo\n          key: cargo-cache\n      - name: Optionally update upstream dependencies\n        if: needs.detect-ci-trigger.outputs.triggered == 'true'\n        run: |\n          bash continuous_integration/scripts/update-dependencies.sh\n      - name: Run tests\n        run: |\n          cargo test\n"
  },
  {
    "path": ".github/workflows/style.yml",
    "content": "---\nname: Python style check\non: [pull_request]\n\n# When this workflow is queued, automatically cancel any previous running\n# or pending jobs from the same branch\nconcurrency:\n  group: style-${{ github.head_ref }}\n  cancel-in-progress: true\n\njobs:\n  pre-commit:\n    name: Run pre-commit hooks\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v4\n      - uses: actions-rs/toolchain@v1\n        with:\n            toolchain: 1.72\n            components: clippy\n            default: true\n      - uses: actions-rs/toolchain@v1\n        with:\n            toolchain: nightly\n            components: rustfmt\n      - uses: pre-commit/action@v3.0.0\n"
  },
  {
    "path": ".github/workflows/test-upstream.yml",
    "content": "name: Nightly upstream testing\non:\n  schedule:\n    - cron: \"0 0 * * *\" # Daily “At 00:00” UTC\n  workflow_dispatch: # allows you to trigger the workflow run manually\n\n# Required shell entrypoint to have properly activated conda environments\ndefaults:\n  run:\n    shell: bash -l {0}\n\njobs:\n  test-dev:\n    name: \"Test upstream dev (${{ matrix.os }}, python: ${{ matrix.python }}, distributed: ${{ matrix.distributed }}, query-planning: ${{ matrix.query-planning }})\"\n    runs-on: ${{ matrix.os }}\n    env:\n      CONDA_FILE: continuous_integration/environment-${{ matrix.python }}.yaml\n      DASK_SQL_DISTRIBUTED_TESTS: ${{ matrix.distributed }}\n      DASK_DATAFRAME__QUERY_PLANNING: ${{ matrix.query-planning }}\n    strategy:\n      fail-fast: false\n      matrix:\n        os: [ubuntu-latest, windows-latest, macos-latest]\n        python: [\"3.9\", \"3.10\", \"3.11\", \"3.12\"]\n        distributed: [false]\n        query-planning: [true]\n        include:\n          # run tests on a distributed client\n          - os: \"ubuntu-latest\"\n            python: \"3.9\"\n            distributed: true\n            query-planning: true\n          - os: \"ubuntu-latest\"\n            python: \"3.11\"\n            distributed: true\n            query-planning: true\n          # run tests with query planning disabled\n          - os: \"ubuntu-latest\"\n            python: \"3.9\"\n            distributed: false\n            query-planning: false\n          - os: \"ubuntu-latest\"\n            python: \"3.11\"\n            distributed: false\n            query-planning: false\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          fetch-depth: 0 # Fetch all history for all branches and tags.\n      - name: Set up Python\n        uses: conda-incubator/setup-miniconda@v2.3.0\n        with:\n          miniforge-variant: Mambaforge\n          use-mamba: true\n          python-version: ${{ matrix.python }}\n          channel-priority: 
strict\n          activate-environment: dask-sql\n          environment-file: ${{ env.CONDA_FILE }}\n      - uses: actions-rs/toolchain@v1\n        with:\n            toolchain: 1.72\n            default: true\n      - name: Install x86_64-apple-darwin target\n        if: matrix.os == 'macos-latest'\n        run: rustup target add x86_64-apple-darwin\n      - name: Build the Rust DataFusion bindings\n        run: |\n          maturin develop\n      - name: Install hive testing dependencies\n        if: matrix.os == 'ubuntu-latest'\n        run: |\n          docker pull bde2020/hive:2.3.2-postgresql-metastore\n          docker pull bde2020/hive-metastore-postgresql:2.3.0\n      - name: Install upstream dev Dask\n        run: |\n          mamba install --no-channel-priority dask/label/dev::dask\n      - name: Install pytest-reportlog\n        run: |\n          # TODO: add pytest-reportlog to testing environments if we move over to JSONL output\n          mamba install pytest-reportlog\n      - name: Test with pytest\n        id: run_tests\n        run: |\n          pytest --report-log test-${{ matrix.os }}-py${{ matrix.python }}-results.jsonl --cov-report=xml -n auto tests --dist loadfile\n      - name: Upload pytest results for failure\n        if: |\n          always()\n          && steps.run_tests.outcome != 'skipped'\n        uses: actions/upload-artifact@v3\n        with:\n          name: test-${{ matrix.os }}-py${{ matrix.python }}-results\n          path: test-${{ matrix.os }}-py${{ matrix.python }}-results.jsonl\n\n  import-dev:\n    name: \"Test importing with bare requirements and upstream dev (query-planning: ${{ matrix.query-planning }})\"\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        query-planning: [true, false]\n    steps:\n      - uses: actions/checkout@v4\n      - name: Set up Python\n        uses: conda-incubator/setup-miniconda@v2.3.0\n        with:\n          miniforge-variant: Mambaforge\n          
use-mamba: true\n          python-version: \"3.9\"\n          channel-priority: strict\n      - uses: actions-rs/toolchain@v1\n        with:\n            toolchain: 1.72\n            default: true\n      - name: Install dependencies and nothing else\n        run: |\n          pip install -e . -vv\n\n          which python\n          pip list\n          mamba list\n      - name: Install upstream dev Dask\n        run: |\n          python -m pip install git+https://github.com/dask/dask\n          python -m pip install git+https://github.com/dask/dask-expr\n          python -m pip install git+https://github.com/dask/distributed\n      - name: Try to import dask-sql\n        env:\n          DASK_DATAFRAME__QUERY_PLANNING: ${{ matrix.query-planning }}\n        run: |\n          python -c \"import dask_sql; print('ok')\"\n\n  report-failures:\n    name: Open issue for upstream dev failures\n    needs: [test-dev, import-dev]\n    if: |\n      always()\n      && (\n        needs.test-dev.result == 'failure'\n        || needs.import-dev.result == 'failure'\n      )\n      && github.repository == 'dask-contrib/dask-sql'\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/download-artifact@v3\n      - name: Prepare logs & issue label\n        run: |\n          # TODO: remove this if xarray-contrib/issue-from-pytest-log no longer needs a log-path\n          if [ -f test-ubuntu-latest-py3.10-results/test-ubuntu-latest-py3.10-results.jsonl ]; then\n              cp test-ubuntu-latest-py3.10-results/test-ubuntu-latest-py3.10-results.jsonl results.jsonl\n          else\n              touch results.jsonl\n          fi\n      - name: Open or update issue on failure\n        uses: xarray-contrib/issue-from-pytest-log@v1.2.6\n        with:\n          log-path: results.jsonl\n          issue-title: ⚠️ Upstream CI failed ⚠️\n          issue-label: upstream\n"
  },
  {
    "path": ".github/workflows/test.yml",
    "content": "name: Test Python package\non:\n  push:\n    branches:\n      - main\n  pull_request:\n\n# When this workflow is queued, automatically cancel any previous running\n# or pending jobs from the same branch\nconcurrency:\n  group: test-${{ github.head_ref }}\n  cancel-in-progress: true\n\n# Required shell entrypoint to have properly activated conda environments\ndefaults:\n  run:\n    shell: bash -l {0}\n\njobs:\n  detect-ci-trigger:\n    name: Check for upstream trigger phrase\n    runs-on: ubuntu-latest\n    if: github.repository == 'dask-contrib/dask-sql'\n    outputs:\n      triggered: ${{ steps.detect-trigger.outputs.trigger-found }}\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          fetch-depth: 2\n      - uses: xarray-contrib/ci-trigger@v1.2\n        id: detect-trigger\n        with:\n          keyword: \"[test-upstream]\"\n\n  test:\n    name: \"Build & Test (${{ matrix.os }}, python: ${{ matrix.python }}, distributed: ${{ matrix.distributed }}, query-planning: ${{ matrix.query-planning }})\"\n    needs: [detect-ci-trigger]\n    runs-on: ${{ matrix.os }}\n    env:\n      CONDA_FILE: continuous_integration/environment-${{ matrix.python }}.yaml\n      DASK_SQL_DISTRIBUTED_TESTS: ${{ matrix.distributed }}\n      DASK_DATAFRAME__QUERY_PLANNING: ${{ matrix.query-planning }}\n    strategy:\n      fail-fast: false\n      matrix:\n        os: [ubuntu-latest, windows-latest, macos-latest]\n        python: [\"3.9\", \"3.10\", \"3.11\", \"3.12\"]\n        distributed: [false]\n        query-planning: [true]\n        include:\n          # run tests on a distributed client\n          - os: \"ubuntu-latest\"\n            python: \"3.9\"\n            distributed: true\n            query-planning: true\n          - os: \"ubuntu-latest\"\n            python: \"3.11\"\n            distributed: true\n            query-planning: true\n          # run tests with query planning disabled\n          - os: \"ubuntu-latest\"\n            python: 
\"3.9\"\n            distributed: false\n            query-planning: false\n          - os: \"ubuntu-latest\"\n            python: \"3.11\"\n            distributed: false\n            query-planning: false\n    steps:\n      - uses: actions/checkout@v4\n      - name: Set up Python\n        uses: conda-incubator/setup-miniconda@v2.3.0\n        with:\n          miniforge-variant: Mambaforge\n          use-mamba: true\n          python-version: ${{ matrix.python }}\n          channel-priority: strict\n          activate-environment: dask-sql\n          environment-file: ${{ env.CONDA_FILE }}\n          run-post: ${{ matrix.os != 'windows-latest' && 'true' || 'false' }}\n      - uses: actions-rs/toolchain@v1\n        with:\n            toolchain: 1.72\n            default: true\n      - name: Install x86_64-apple-darwin target\n        if: matrix.os == 'macos-latest'\n        run: rustup target add x86_64-apple-darwin\n      - name: Build the Rust DataFusion bindings\n        run: |\n          maturin develop\n      - name: Install hive testing dependencies\n        if: matrix.os == 'ubuntu-latest'\n        run: |\n          docker pull bde2020/hive:2.3.2-postgresql-metastore\n          docker pull bde2020/hive-metastore-postgresql:2.3.0\n      - name: Optionally install upstream dev Dask\n        if: needs.detect-ci-trigger.outputs.triggered == 'true'\n        run: |\n          mamba install --no-channel-priority dask/label/dev::dask\n      - name: Test with pytest\n        run: |\n          pytest --junitxml=junit/test-results.xml --cov-report=xml -n auto tests --dist loadfile\n      - name: Upload pytest test results\n        if: always()\n        uses: actions/upload-artifact@v3\n        with:\n          name: pytest-results\n          path: junit/test-results.xml\n      - name: Upload coverage to Codecov\n        if: github.repository == 'dask-contrib/dask-sql'\n        uses: codecov/codecov-action@v3\n\n  import:\n    name: \"Test importing with bare 
requirements (query-planning: ${{ matrix.query-planning }})\"\n    needs: [detect-ci-trigger]\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        query-planning: [true, false]\n    steps:\n      - uses: actions/checkout@v4\n      - name: Set up Python\n        uses: conda-incubator/setup-miniconda@v2.3.0\n        with:\n          miniforge-variant: Mambaforge\n          use-mamba: true\n          python-version: \"3.9\"\n          channel-priority: strict\n      - uses: actions-rs/toolchain@v1\n        with:\n            toolchain: 1.72\n            default: true\n      - name: Install dependencies and nothing else\n        run: |\n          pip install -e . -vv\n\n          which python\n          pip list\n          mamba list\n      - name: Optionally install upstream dev Dask\n        if: needs.detect-ci-trigger.outputs.triggered == 'true'\n        run: |\n          python -m pip install git+https://github.com/dask/dask\n          python -m pip install git+https://github.com/dask/dask-expr\n          python -m pip install git+https://github.com/dask/distributed\n      - name: Try to import dask-sql\n        env:\n          DASK_DATAFRAME__QUERY_PLANNING: ${{ matrix.query-planning }}\n        run: |\n          python -c \"import dask_sql; print('ok')\"\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n*.so\n\n# Unit test / coverage reports\nhtmlcov/\n.coverage\n.coverage.*\n.cache\ncoverage.xml\n*.cover\n.pytest_cache/\n.hypothesis/\n.pytest-html\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# environments\nconda-env\nenv\nvenv\n\n# IDE\n.idea\n.vscode\n*.swp\n\n# project specific\ndask-worker-space/\nnode_modules/\ndocs/source/_build/\ntests/unit/queries\ntests/unit/data\ntarget/*\npackages/*\n\n# Ignore development specific local testing files\ndev_tests\ndev-tests\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  - repo: https://github.com/psf/black\n    rev: 22.10.0\n    hooks:\n      - id: black\n        language_version: python3\n  - repo: https://github.com/PyCQA/flake8\n    rev: 5.0.4\n    hooks:\n      - id: flake8\n        language_version: python3\n  - repo: https://github.com/pycqa/isort\n    rev: 5.12.0\n    hooks:\n      - id: isort\n        args:\n          - \"--profile\"\n          - \"black\"\n  - repo: https://github.com/doublify/pre-commit-rust\n    rev: v1.0\n    hooks:\n      - id: cargo-check\n        args: ['--manifest-path', './Cargo.toml', '--verbose', '--']\n      - id: clippy\n        args: ['--manifest-path', './Cargo.toml', '--verbose', '--', '-D', 'warnings']\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.2.0\n    hooks:\n      - id: trailing-whitespace\n      - id: end-of-file-fixer\n      - id: check-yaml\n        exclude: ^continuous_integration/recipe/\n      - id: check-added-large-files\n  - repo: local\n    hooks:\n      - id: cargo-fmt\n        name: cargo fmt\n        description: Format files with cargo fmt.\n        entry: cargo +nightly fmt\n        language: system\n        types: [rust]\n        args: ['--manifest-path', './Cargo.toml', '--verbose', '--']\n"
  },
  {
    "path": ".readthedocs.yaml",
    "content": "# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details\nversion: 2\nbuild:\n  os: ubuntu-20.04\n  tools:\n    python: \"mambaforge-4.10\"\n\nsphinx:\n  configuration: docs/source/conf.py\n\nconda:\n  environment: docs/environment.yml\n\npython:\n  install:\n    - method: pip\n      path: .\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and maintainers pledge to making participation in our project and\nour community a harassment-free experience for everyone, regardless of age, body\nsize, disability, ethnicity, sex characteristics, gender identity and expression,\nlevel of experience, education, socio-economic status, nationality, personal\nappearance, race, religion, or sexual identity and orientation.\n\n## Our Standards\n\nExamples of behavior that contributes to creating a positive environment\ninclude:\n\n* Using welcoming and inclusive language\n* Being respectful of differing viewpoints and experiences\n* Gracefully accepting constructive criticism\n* Focusing on what is best for the community\n* Showing empathy towards other community members\n\nExamples of unacceptable behavior by participants include:\n\n* The use of sexualized language or imagery and unwelcome sexual attention or\n advances\n* Trolling, insulting/derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or electronic\n address, without explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\n professional setting\n\n## Our Responsibilities\n\nProject maintainers are responsible for clarifying the standards of acceptable\nbehavior and are expected to take appropriate and fair corrective action in\nresponse to any instances of unacceptable behavior.\n\nProject maintainers have the right and responsibility to remove, edit, or\nreject comments, commits, code, wiki edits, issues, and other contributions\nthat are not aligned to this Code of Conduct, or to ban temporarily or\npermanently any contributor for other behaviors that they deem inappropriate,\nthreatening, offensive, or harmful.\n\n## Scope\n\nThis Code of Conduct applies both within 
project spaces and in public spaces\nwhen an individual is representing the project or its community. Examples of\nrepresenting a project or community include using an official project e-mail\naddress, posting via an official social media account, or acting as an appointed\nrepresentative at an online or offline event. Representation of a project may be\nfurther defined and clarified by project maintainers.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported by contacting the project team at nilslennartbraun@gmail.com. All\ncomplaints will be reviewed and investigated and will result in a response that\nis deemed necessary and appropriate to the circumstances. The project team is\nobligated to maintain confidentiality with regard to the reporter of an incident.\nFurther details of specific enforcement policies may be posted separately.\n\nProject maintainers who do not follow or enforce the Code of Conduct in good\nfaith may face temporary or permanent repercussions as determined by other\nmembers of the project's leadership.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,\navailable at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html\n\n[homepage]: https://www.contributor-covenant.org\n\nFor answers to common questions about this code of conduct, see\nhttps://www.contributor-covenant.org/faq\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to Dask-SQL\n\n## Environment Setup\n\nThe environment used for development and CI consists of:\n\n- a system installation of [`rustup`](https://rustup.rs/) with:\n    - the latest stable toolchain\n    - the latest nightly `rustfmt`\n- a [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) environment containing all required Python packages\n\nOnce `rustup` is installed, ensure that the latest stable toolchain and nightly `rustfmt` are available by running\n\n```\nrustup toolchain install nightly -c rustfmt --profile minimal\nrustup update\n```\n\nTo initialize and activate the conda environment for a given Python version:\n\n```\nconda env create -f dask-sql/continuous_integration/environment-{$PYTHON_VER}.yaml\nconda activate dask-sql\n```\n\n## Rust Developers Guide\n\nDask-SQL utilizes [Apache Arrow Datafusion](https://github.com/apache/arrow-datafusion) for parsing, planning, and optimizing SQL queries. DataFusion is written in Rust and therefore requires some Rust experience to be productive. Luckily, there are tons of great Rust learning resources on the internet. We have listed some of our favorite ones [here](#rust-learning-resources)\n\n### Apache Arrow DataFusion\nThe Dask-SQL Rust codebase makes heavy use [Apache Arrow DataFusion](https://github.com/apache/arrow-datafusion). Contributors should familiarize themselves with the [codebase](https://github.com/apache/arrow-datafusion) and [documentation](https://docs.rs/datafusion/latest/datafusion/).\n\n#### Purpose\nDataFusion provides Dask-SQL with key functionality.\n- Parsing SQL query strings into a `LogicalPlan` datastructure\n- Future integration points with [substrait.io](https://substrait.io/)\n- An optimization framework used as the baseline for creating custom highly efficient `LogicalPlan`s specific to Dask.\n\n### Building\nBuilding the Dask-SQL Rust codebase is a straightforward process. 
If you create and activate the Dask-SQL conda environment, the Rust compiler and all necessary components are installed for you during that process, so no further manual setup is required.\n\n`maturin` is used by Dask-SQL for building and bundling the resulting Rust binaries. This helps make building and installing the Rust binaries feel much more like a native Python workflow.\n\nMore details about the build setup can be found in [pyproject.toml](pyproject.toml) and [Cargo.toml](Cargo.toml).\n\nNote that while `maturin` is used by CI and should be used during your development cycle, you can opt to use `cargo` directly from the command line if you need to do something more specific that `maturin` does not yet support.\n\n#### Building with Python\nBuilding Dask-SQL is straightforward with Python. To build, run ```pip install .```. This will build both the Rust and Python codebases and install them into your locally activated conda environment. Note that if your Rust dependencies have been updated, this command must be rerun to rebuild the Rust codebase.\n\n#### DataFusion Modules\nDataFusion is broken down into a few modules. We consume those modules in our [Cargo.toml](Cargo.toml). The modules that we currently use are:\n\n- `datafusion-common` - Datastructures and core logic\n- `datafusion-expr` - Expression based logic and operators\n- `datafusion-sql` - SQL components such as parsing and planning\n- `datafusion-optimizer` - Optimization logic and datastructures for modifying current plans into more efficient ones.\n\n#### Retrieving Upstream Dependencies\nDuring development you might find yourself needing some upstream DataFusion changes not present in the project's current version. Luckily, this can easily be achieved by updating [Cargo.toml](Cargo.toml) and changing the `rev` to the SHA of the version you need. 
Note that the same SHA should be used for all DataFusion modules.\n\n#### Local Documentation\nSometimes when building against the latest GitHub commits for DataFusion you may find that the features you are consuming do not have their documentation public yet. In this case it can be helpful to build the DataFusion documentation locally so that it can be referenced to assist with development. Here is a rough outline for building that documentation locally.\n\n- clone https://github.com/apache/arrow-datafusion\n- change into the `arrow-datafusion` directory\n- run `cargo doc`\n- navigate to `target/doc/datafusion/all.html` and open in your desired browser\n\n### Datastructures\nWhile working in the Rust codebase there are a few datastructures that you should make yourself familiar with. This section does not aim to verbosely list out all of the datastructures within the project but rather just the key datastructures that you are likely to encounter while working on almost any feature/issue. The aim is to give you a better overview of the codebase without having to manually dig through all the source code.\n\n- [`PyLogicalPlan`](src/sql/logical.rs) -> [DataFusion LogicalPlan](https://docs.rs/datafusion/latest/datafusion/logical_plan/enum.LogicalPlan.html)\n    - Often encountered in Python code with variable name `rel`\n    - Python serializable umbrella representation of the entire LogicalPlan that was generated by DataFusion\n    - Provides access to `DaskTable` instances and type information for each table\n    - Access to individual nodes in the logical plan tree. Ex: `TableScan`\n- [`DaskSQLContext`](src/sql.rs)\n    - Analogous to Python `Context`\n    - Contains metadata about the tables, schemas, functions, operators, and configurations that are present within the current execution context\n    - When adding custom functions/UDFs this is the location where you would register them\n    - Entry point for parsing SQL strings to SQL node trees. 
This is where Python begins its interactions with Rust\n- [`PyExpr`](src/expression.rs) -> [DataFusion Expr](https://docs.rs/datafusion/latest/datafusion/prelude/enum.Expr.html)\n    - Arguably where most of your time will be spent\n    - Represents a single node in the SQL tree. Ex: `avg(age)` from `SELECT avg(age) FROM people`\n    - Is associated with a single `RexType`\n    - Can contain literal values or represent function calls, `avg()` for example\n    - The expression's \"index\" in the tree can be retrieved by calling `PyExpr.index()` on an instance. This is useful when mapping frontend column names in Dask code to backend Dataframe columns\n    - Certain `PyExpr`s contain operands. Ex: `2 + 2` would contain 3 operands. 1) A literal `PyExpr` instance with a value of 2. 2) Another literal `PyExpr` instance with a value of 2. 3) A `+` `PyExpr` representing the addition of the 2 literals.\n- [`DaskSqlOptimizer`](src/sql/optimizer.rs)\n    - Registering location for all Dask-SQL specific logical plan optimizations\n    - Optimizations, whether custom-written or taken from another source such as DataFusion, are registered here in the order in which they should be executed\n    - Represents functions that modify/convert an original `PyLogicalPlan` into another `PyLogicalPlan` that would be more efficient when running in the underlying Dask framework\n- [`RelDataType`](src/sql/types/rel_data_type.rs)\n    - Not a fan of this name; it was chosen to match existing Calcite logic\n    - Represents a \"row\" in a table\n    - Contains a list of \"columns\" that are present in that row\n        - [RelDataTypeField](src/sql/types/rel_data_type_field.rs)\n- [RelDataTypeField](src/sql/types/rel_data_type_field.rs)\n    - Represents an individual column in a table\n    - Contains:\n        - `qualifier` - schema the field belongs to\n        - `name` - name of the column/field\n        - `data_type` - `DaskTypeMap` instance containing information about the SQL type and underlying 
Arrow DataType\n        - `index` - location of the field in the LogicalPlan\n- [DaskTypeMap](src/sql/types.rs)\n    - Maps a conventional SQL type to an underlying Arrow DataType\n\n\n### Rust Learning Resources\n- [\"The Book\"](https://doc.rust-lang.org/book/)\n- [Let's Get Rusty \"LGR\" YouTube series](https://www.youtube.com/c/LetsGetRusty)\n\n## Documentation TODO\n- [ ] SQL Parsing overview diagram\n- [ ] Architecture diagram\n- [x] Setup dev environment\n- [x] Version of Rust and specs\n- [x] Updating version of datafusion\n- [x] Building\n- [x] Rust learning resources\n- [x] Rust Datastructures local to Dask-SQL\n- [x] Build DataFusion documentation locally\n- [ ] Python & Rust with PyO3\n- [ ] Types mapping, Arrow datatypes\n- [ ] RexTypes explanation, show simple query and show it broken down into its parts in a diagram\n- [ ] Registering tables with DaskSqlContext, also functions\n- [ ] Creating your own optimizer\n- [ ] Simple diagram of PyExpr, showing something like 2+2 but broken down into a tree-like diagram\n"
  },
  {
    "path": "Cargo.toml",
    "content": "[package]\nname = \"dask-sql\"\nrepository = \"https://github.com/dask-contrib/dask-sql\"\nversion = \"2024.5.0\"\ndescription = \"Bindings for DataFusion used by Dask-SQL\"\nreadme = \"README.md\"\nlicense = \"Apache-2.0\"\nedition = \"2021\"\nrust-version = \"1.72\"\ninclude = [\"/src\", \"/dask_sql\", \"/LICENSE.txt\", \"pyproject.toml\", \"Cargo.toml\", \"Cargo.lock\"]\n\n[dependencies]\nasync-trait = \"0.1.78\"\ndatafusion-python = { git = \"https://github.com/apache/arrow-datafusion-python.git\", ref = \"da6c183\" }\nenv_logger = \"0.11\"\nlog = \"^0.4\"\npyo3 = { version = \"0.19.2\", features = [\"extension-module\", \"abi3\", \"abi3-py39\"] }\npyo3-log = \"0.9.0\"\n\n[build-dependencies]\npyo3-build-config = \"0.20.3\"\n\n[lib]\nname = \"dask_sql\"\ncrate-type = [\"cdylib\", \"rlib\"]\n\n[profile.release]\nlto = true\ncodegen-units = 1\n"
  },
  {
    "path": "LICENSE.txt",
    "content": "MIT LICENCE\n\nCopyright (c) 2020 Nils Braun\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of this software and associated\ndocumentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the\nrights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit\npersons to whom the Software is furnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all copies or substantial portions of the\nSoftware.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE\nWARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR\nCOPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR\nOTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "recursive-include dask_sql *.yaml\nrecursive-include dask_planner *\n"
  },
  {
    "path": "README.md",
    "content": "**Dask-SQL is currently not in active maintenance, see [#1344](https://github.com/dask-contrib/dask-sql/issues/1344) for more information**\n\n\n[![Conda](https://img.shields.io/conda/v/conda-forge/dask-sql)](https://anaconda.org/conda-forge/dask-sql)\n[![PyPI](https://img.shields.io/pypi/v/dask-sql?logo=pypi)](https://pypi.python.org/pypi/dask-sql/)\n[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/dask-contrib/dask-sql/test.yml?branch=main)](https://github.com/dask-contrib/dask-sql/actions/workflows/test.yml?query=branch%3Amain)\n[![Read the Docs](https://img.shields.io/readthedocs/dask-sql)](https://dask-sql.readthedocs.io/en/latest/)\n[![Codecov](https://img.shields.io/codecov/c/github/dask-contrib/dask-sql?logo=codecov)](https://codecov.io/gh/dask-contrib/dask-sql)\n[![GitHub](https://img.shields.io/github/license/dask-contrib/dask-sql)](https://github.com/dask-contrib/dask-sql/blob/main/LICENSE.txt)\n[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/dask-contrib/dask-sql-binder/main?urlpath=lab)\n\n<div align=\"center\">\n    <img src=\"./.github/heart.png\" alt=\"SQL + Python\">\n</div>\n\n`dask-sql` is a distributed SQL query engine in Python.\nIt allows you to query and transform your data using a mixture of\ncommon SQL operations and Python code and also scale up the calculation easily\nif you need it.\n\n* **Combine the power of Python and SQL**: load your data with Python, transform it with SQL, enhance it with Python and query it with SQL - or the other way round.\n  With `dask-sql` you can mix the well known Python dataframe API of `pandas` and `Dask` with common SQL operations, to\n  process your data in exactly the way that is easiest for you.\n* **Infinite Scaling**: using the power of the great `Dask` ecosystem, your computations can scale as you need it - from your laptop to your super cluster - without changing any line of SQL code. 
From k8s to cloud deployments, from batch systems to YARN - if `Dask` [supports it](https://docs.dask.org/en/latest/setup.html), so will `dask-sql`.\n* **Your data - your queries**: Use Python user-defined functions (UDFs) in SQL without any performance drawback and extend your SQL queries with the large number of Python libraries, e.g. machine learning, different complicated input formats, complex statistics.\n* **Easy to install and maintain**: `dask-sql` is just a pip/conda install away (or a docker run if you prefer).\n* **Use SQL from wherever you like**: `dask-sql` integrates with your jupyter notebook, your normal Python module or can be used as a standalone SQL server from any BI tool. It even integrates natively with [Apache Hue](https://gethue.com/).\n* **GPU Support**: `dask-sql` supports running SQL queries on CUDA-enabled GPUs by utilizing [RAPIDS](https://rapids.ai) libraries like [`cuDF`](https://github.com/rapidsai/cudf), enabling accelerated compute for SQL.\n\nRead more in the [documentation](https://dask-sql.readthedocs.io/en/latest/).\n\n<div align=\"center\">\n    <img src=\"./.github/animation.gif\" alt=\"dask-sql GIF\">\n</div>\n\n---\n\n## Example\n\nFor this example, we use some data loaded from disk and query them with a SQL command from our python code.\nAny pandas or dask dataframe can be used as input and ``dask-sql`` understands a large amount of formats (csv, parquet, json,...) and locations (s3, hdfs, gcs,...).\n\n```python\nimport dask.dataframe as dd\nfrom dask_sql import Context\n\n# Create a context to hold the registered tables\nc = Context()\n\n# Load the data and register it in the context\n# This will give the table a name, that we can use in queries\ndf = dd.read_csv(\"...\")\nc.create_table(\"my_data\", df)\n\n# Now execute a SQL query. 
Because return_futures=False is passed, the result is computed eagerly and returned as a pandas dataframe (the default returns a lazy dask dataframe).\nresult = c.sql(\"\"\"\n    SELECT\n        my_data.name,\n        SUM(my_data.x)\n    FROM\n        my_data\n    GROUP BY\n        my_data.name\n\"\"\", return_futures=False)\n\n# Show the result\nprint(result)\n```\n\n## Quickstart\n\nHave a look at the [documentation](https://dask-sql.readthedocs.io/en/latest/) or start the example notebook on [binder](https://mybinder.org/v2/gh/dask-contrib/dask-sql-binder/main?urlpath=lab).\n\n\n> `dask-sql` is currently under development and does not yet understand all SQL commands (but a large fraction of them).\nWe are actively looking for feedback, improvements and contributors!\n\n## Installation\n\n`dask-sql` can be installed via `conda` (preferred) or `pip` - or in a development environment.\n\n### With `conda`\n\nCreate a new conda environment or use your already present environment:\n\n    conda create -n dask-sql\n    conda activate dask-sql\n\nInstall the package from the `conda-forge` channel:\n\n    conda install dask-sql -c conda-forge\n\n### With `pip`\n\nYou can install the package with\n\n    pip install dask-sql\n\n### For development\n\nIf you want to have the newest (unreleased) `dask-sql` version or if you plan to do development on `dask-sql`, you can also install the package from sources.\n\n    git clone https://github.com/dask-contrib/dask-sql.git\n\nCreate a new conda environment and install the development environment:\n\n    conda env create -f continuous_integration/environment-3.9.yaml\n\nIt is not recommended to use `pip` instead of `conda` for the environment setup.\n\nAfter that, you can install the package in development mode\n\n    pip install -e \".[dev]\"\n\nThe Rust DataFusion bindings are built as part of the `pip install`.\nNote that if changes are made to the Rust source in `src/`, another build must be run to recompile the bindings.\nThis repository uses [pre-commit](https://pre-commit.com/) hooks. 
To install them, call\n\n    pre-commit install\n\n## Testing\n\nYou can run the tests (after installation) with\n\n    pytest tests\n\nGPU-specific tests require additional dependencies specified in `continuous_integration/gpuci/environment.yaml`.\nThese can be added to the development environment by running\n\n```\nconda env update -n dask-sql -f continuous_integration/gpuci/environment.yaml\n```\n\nAnd GPU-specific tests can be run with\n\n```\npytest tests -m gpu --rungpu\n```\n\n## SQL Server\n\n`dask-sql` comes with a small test implementation for a SQL server.\nInstead of rebuilding a full ODBC driver, we re-use the [presto wire protocol](https://github.com/prestodb/presto/wiki/HTTP-Protocol).\nIt is - so far - only a start of the development and missing important concepts, such as\nauthentication.\n\nYou can test the sql presto server by running (after installation)\n\n    dask-sql-server\n\nor by using the created docker image\n\n    docker run --rm -it -p 8080:8080 nbraun/dask-sql\n\nin one terminal. 
This will spin up a server on port 8080 (by default)\nthat looks similar to a normal presto database to any presto client.\n\nYou can test this for example with the default [presto client](https://prestosql.io/docs/current/installation/cli.html):\n\n    presto --server localhost:8080\n\nNow you can fire simple SQL queries (as no data is loaded by default):\n\n    => SELECT 1 + 1;\n     EXPR$0\n    --------\n        2\n    (1 row)\n\nYou can find more information in the [documentation](https://dask-sql.readthedocs.io/en/latest/pages/server.html).\n\n## CLI\n\nYou can also run the CLI `dask-sql` for testing out SQL commands quickly:\n\n    dask-sql --load-test-data --startup\n\n    (dask-sql) > SELECT * FROM timeseries LIMIT 10;\n\n## How does it work?\n\nAt the core, `dask-sql` does two things:\n\n- translate the SQL query using [DataFusion](https://arrow.apache.org/datafusion) into a relational algebra, which is represented as a logical query plan - similar to many other SQL engines (Hive, Flink, ...)\n- convert this description of the query into dask API calls (and execute them) - returning a dask dataframe.\n\nFor the first step, Arrow DataFusion needs to know about the columns and types of the dask dataframes, so some Rust code that stores this information for dask dataframes is defined in `dask_planner`.\nAfter the translation to a relational algebra is done (using `DaskSQLContext.logical_relational_algebra`), the Python methods defined in `dask_sql.physical` turn this into a physical dask execution plan by converting each piece of the relational algebra one by one.\n"
  },
  {
    "path": "conftest.py",
    "content": "import dask\nimport pytest\n\npytest_plugins = [\"tests.integration.fixtures\"]\n\n\ndef pytest_addoption(parser):\n    parser.addoption(\"--rungpu\", action=\"store_true\", help=\"run tests meant for GPU\")\n    parser.addoption(\"--runqueries\", action=\"store_true\", help=\"run test queries\")\n    parser.addoption(\"--data_dir\", help=\"specify file path to the data\")\n    parser.addoption(\"--queries_dir\", help=\"specify file path to the queries\")\n\n\ndef pytest_runtest_setup(item):\n    # TODO: get pyarrow strings and p2p shuffle working\n    dask.config.set({\"dataframe.convert-string\": False})\n    dask.config.set({\"dataframe.shuffle.method\": \"tasks\"})\n    if \"gpu\" in item.keywords:\n        if not item.config.getoption(\"--rungpu\"):\n            pytest.skip(\"need --rungpu option to run\")\n        # manually enable cudf decimal support\n        dask.config.set({\"sql.mappings.decimal_support\": \"cudf\"})\n    if \"queries\" in item.keywords and not item.config.getoption(\"--runqueries\"):\n        pytest.skip(\"need --runqueries option to run\")\n\n\n@pytest.fixture(scope=\"session\")\ndef data_dir(request):\n    return request.config.getoption(\"--data_dir\")\n\n\n@pytest.fixture(scope=\"session\")\ndef queries_dir(request):\n    return request.config.getoption(\"--queries_dir\")\n"
  },
  {
    "path": "continuous_integration/docker/cloud.dockerfile",
    "content": "ARG DOCKER_META_VERSION\nFROM nbraun/dask-sql:${DOCKER_META_VERSION}\n\nRUN conda config --add channels conda-forge \\\n    && /opt/conda/bin/mamba install --freeze-installed -y \\\n    s3fs \\\n    dask-cloudprovider \\\n    && pip install awscli \\\n    && conda clean -ay\n\nENTRYPOINT [\"tini\", \"-g\", \"--\", \"/usr/bin/prepare.sh\"]\n"
  },
  {
    "path": "continuous_integration/docker/conda.txt",
    "content": "python>=3.9\ndask>=2024.4.1\npandas>=1.4.0\njpype1>=1.0.2\nopenjdk>=8\nmaven>=3.6.0\npytest>=6.0.2\npytest-cov>=2.10.1\npytest-xdist\nmock>=4.0.3\nsphinx>=3.2.1\ntzlocal>=2.1\nfastapi>=0.92.0\nhttpx>=0.24.1\nuvicorn>=0.14\npyarrow>=14.0.1\nprompt_toolkit>=3.0.8\npygments>=2.7.1\nscikit-learn>=1.0.0\nintake>=0.6.0\npre-commit>=2.11.1\nblack=22.10.0\nisort=5.12.0\nmaturin>=1.3,<1.4\n"
  },
  {
    "path": "continuous_integration/docker/main.dockerfile",
    "content": "# Dockerfile for dask-sql running the SQL server\n# For more information, see https://dask-sql.readthedocs.io/.\nFROM daskdev/dask:latest\nLABEL author \"Nils Braun <nilslennartbraun@gmail.com>\"\n\n# Install rustc & gcc for compilation of DataFusion planner\nADD https://sh.rustup.rs /rustup-init.sh\nRUN sh /rustup-init.sh -y --default-toolchain=stable --profile=minimal \\\n    && apt-get update \\\n    && apt-get install gcc -y\nENV PATH=\"/root/.cargo/bin:${PATH}\"\n\n# Install conda dependencies for dask-sql\nCOPY continuous_integration/docker/conda.txt /opt/dask_sql/\nRUN mamba install -y \\\n    # build requirements\n    \"maturin>=1.3,<1.4\" \\\n    # core dependencies\n    \"dask>=2024.4.1\" \\\n    \"pandas>=1.4.0\" \\\n    \"fastapi>=0.92.0\" \\\n    \"httpx>=0.24.1\" \\\n    \"uvicorn>=0.14\" \\\n    \"tzlocal>=2.1\" \\\n    \"prompt_toolkit>=3.0.8\" \\\n    \"pygments>=2.7.1\" \\\n    tabulate \\\n    # additional dependencies\n    \"pyarrow>=14.0.1\" \\\n    \"scikit-learn>=1.0.0\" \\\n    \"intake>=0.6.0\" \\\n    && conda clean -ay\n\n# install dask-sql\nCOPY Cargo.toml /opt/dask_sql/\nCOPY Cargo.lock /opt/dask_sql/\nCOPY pyproject.toml /opt/dask_sql/\nCOPY setup.cfg /opt/dask_sql/\nCOPY README.md /opt/dask_sql/\nCOPY .git /opt/dask_sql/.git\nCOPY src /opt/dask_sql/src\nCOPY dask_sql /opt/dask_sql/dask_sql\nRUN cd /opt/dask_sql/ \\\n    && CONDA_PREFIX=\"/opt/conda/\" maturin develop\n\n# Set the script to execute\nCOPY continuous_integration/scripts/startup_script.py /opt/dask_sql/startup_script.py\n\nEXPOSE 8080\nENTRYPOINT [ \"/usr/bin/prepare.sh\", \"/opt/conda/bin/python\", \"/opt/dask_sql/startup_script.py\" ]\n"
  },
  {
    "path": "continuous_integration/environment-3.10.yaml",
    "content": "name: dask-sql\nchannels:\n- conda-forge\ndependencies:\n- c-compiler\n- dask>=2024.4.1\n- dask-expr>=1.0.11\n- docker-py>=7.1.0\n- fastapi>=0.92.0\n- fugue>=0.7.3\n- httpx>=0.24.1\n- intake>=0.6.0\n- jsonschema\n- lightgbm\n- maturin>=1.3,<1.4\n- mlflow>=2.10\n- mock\n- numpy>=1.22.4\n- pandas>=2\n- pre-commit\n- prompt_toolkit>=3.0.8\n- psycopg2\n- pyarrow>=14.0.1\n- pygments>=2.7.1\n- pyhive\n- pytest-cov\n- pytest-rerunfailures\n- pytest-xdist\n- pytest\n- python=3.10\n- py-xgboost>=2.0.3\n- scikit-learn>=1.0.0\n- sphinx\n- sqlalchemy\n- tpot>=0.12.0\n# FIXME: https://github.com/fugue-project/fugue/issues/526\n- triad<0.9.2\n- tzlocal>=2.1\n- uvicorn>=0.14\n- zlib\n"
  },
  {
    "path": "continuous_integration/environment-3.11.yaml",
    "content": "name: dask-sql\nchannels:\n- conda-forge\ndependencies:\n- c-compiler\n- dask>=2024.4.1\n- dask-expr>=1.0.11\n- docker-py>=7.1.0\n- fastapi>=0.92.0\n- fugue>=0.7.3\n- httpx>=0.24.1\n- intake>=0.6.0\n- jsonschema\n- lightgbm\n- maturin>=1.3,<1.4\n- mlflow>=2.10\n- mock\n- numpy>=1.22.4\n- pandas>=2\n- pre-commit\n- prompt_toolkit>=3.0.8\n- psycopg2\n- pyarrow>=14.0.1\n- pygments>=2.7.1\n- pyhive\n- pytest-cov\n- pytest-rerunfailures\n- pytest-xdist\n- pytest\n- python=3.11\n- py-xgboost>=2.0.3\n- scikit-learn>=1.0.0\n- sphinx\n- sqlalchemy\n- tpot>=0.12.0\n# FIXME: https://github.com/fugue-project/fugue/issues/526\n- triad<0.9.2\n- tzlocal>=2.1\n- uvicorn>=0.14\n- zlib\n"
  },
  {
    "path": "continuous_integration/environment-3.12.yaml",
    "content": "name: dask-sql\nchannels:\n- conda-forge\ndependencies:\n- c-compiler\n- dask>=2024.4.1\n- dask-expr>=1.0.11\n- docker-py>=7.1.0\n- fastapi>=0.92.0\n- fugue>=0.7.3\n- httpx>=0.24.1\n- intake>=0.6.0\n- jsonschema\n- lightgbm\n- maturin>=1.3,<1.4\n# TODO: add once mlflow 3.12 builds are available\n# - mlflow>=2.10\n- mock\n- numpy>=1.22.4\n- pandas>=2\n- pre-commit\n- prompt_toolkit>=3.0.8\n- psycopg2\n- pyarrow>=14.0.1\n- pygments>=2.7.1\n- pyhive\n- pytest-cov\n- pytest-rerunfailures\n- pytest-xdist\n- pytest\n- python=3.12\n- py-xgboost>=2.0.3\n- scikit-learn>=1.0.0\n- sphinx\n- sqlalchemy\n# TODO: add once tpot supports python 3.12\n# - tpot>=0.12.0\n# FIXME: https://github.com/fugue-project/fugue/issues/526\n- triad<0.9.2\n- tzlocal>=2.1\n- uvicorn>=0.14\n- zlib\n"
  },
  {
    "path": "continuous_integration/environment-3.9.yaml",
    "content": "name: dask-sql-py39\nchannels:\n- conda-forge\ndependencies:\n- c-compiler\n- dask=2024.4.1\n- dask-expr=1.0.11\n- docker-py>=7.1.0\n- fastapi=0.92.0\n- fugue=0.7.3\n- httpx=0.24.1\n- intake=0.6.0\n- jsonschema\n- lightgbm\n- maturin=1.3\n- mlflow=2.10\n- mock\n- numpy=1.22.4\n- pandas=2\n- pre-commit\n- prompt_toolkit=3.0.8\n- psycopg2\n- pyarrow=14.0.1\n- pygments=2.7.1\n- pyhive\n- pytest-cov\n- pytest-rerunfailures\n- pytest-xdist\n- pytest\n- python=3.9\n- py-xgboost=2.0.3\n- scikit-learn=1.0.0\n- sphinx\n- sqlalchemy\n- tpot>=0.12.0\n# FIXME: https://github.com/fugue-project/fugue/issues/526\n- triad<0.9.2\n- tzlocal=2.1\n- uvicorn=0.14\n- zlib\n"
  },
  {
    "path": "continuous_integration/gpuci/environment-3.10.yaml",
    "content": "name: dask-sql\nchannels:\n- rapidsai\n- rapidsai-nightly\n- dask/label/dev\n- conda-forge\n- nvidia\n- nodefaults\ndependencies:\n- c-compiler\n- zlib\n- dask>=2024.4.1\n- dask-expr>=1.0.11\n- fastapi>=0.92.0\n- fugue>=0.7.3\n- httpx>=0.24.1\n- intake>=0.6.0\n- jsonschema\n- lightgbm\n- maturin>=1.3,<1.4\n- mock\n- numpy>=1.22.4\n- pandas>=2\n- pre-commit\n- prompt_toolkit>=3.0.8\n- psycopg2\n- pyarrow>=14.0.1\n- pygments>=2.7.1\n- pyhive\n- pytest-cov\n- pytest-rerunfailures\n- pytest-xdist\n- pytest\n- python=3.10\n- py-xgboost>=2.0.3\n- scikit-learn>=1.0.0\n- sphinx\n- sqlalchemy\n- tpot>=0.12.0\n# FIXME: https://github.com/fugue-project/fugue/issues/526\n- triad<0.9.2\n- tzlocal>=2.1\n- uvicorn>=0.14\n# GPU-specific requirements\n- cudatoolkit=11.8\n- cudf=24.06\n- cuml=24.06\n- dask-cudf=24.06\n- dask-cuda=24.06\n- ucx-proc=*=gpu\n- ucx-py=0.38\n- xgboost=*=rapidsai_py*\n- libxgboost=*=rapidsai_h*\n"
  },
  {
    "path": "continuous_integration/gpuci/environment-3.11.yaml",
    "content": "name: dask-sql\nchannels:\n- rapidsai\n- rapidsai-nightly\n- dask/label/dev\n- conda-forge\n- nvidia\n- nodefaults\ndependencies:\n- c-compiler\n- zlib\n- dask>=2024.4.1\n- dask-expr>=1.0.11\n- fastapi>=0.92.0\n- fugue>=0.7.3\n- httpx>=0.24.1\n- intake>=0.6.0\n- jsonschema\n- lightgbm\n- maturin>=1.3,<1.4\n- mock\n- numpy>=1.22.4\n- pandas>=2\n- pre-commit\n- prompt_toolkit>=3.0.8\n- psycopg2\n- pyarrow>=14.0.1\n- pygments>=2.7.1\n- pyhive\n- pytest-cov\n- pytest-rerunfailures\n- pytest-xdist\n- pytest\n- python=3.11\n- py-xgboost>=2.0.3\n- scikit-learn>=1.0.0\n- sphinx\n- sqlalchemy\n- tpot>=0.12.0\n# FIXME: https://github.com/fugue-project/fugue/issues/526\n- triad<0.9.2\n- tzlocal>=2.1\n- uvicorn>=0.14\n# GPU-specific requirements\n- cudatoolkit=11.8\n- cudf=24.06\n- cuml=24.06\n- dask-cudf=24.06\n- dask-cuda=24.06\n- ucx-proc=*=gpu\n- ucx-py=0.38\n- xgboost=*=rapidsai_py*\n- libxgboost=*=rapidsai_h*\n"
  },
  {
    "path": "continuous_integration/gpuci/environment-3.9.yaml",
    "content": "name: dask-sql\nchannels:\n- rapidsai\n- rapidsai-nightly\n- dask/label/dev\n- conda-forge\n- nvidia\n- nodefaults\ndependencies:\n- c-compiler\n- zlib\n- dask>=2024.4.1\n- dask-expr>=1.0.11\n- fastapi>=0.92.0\n- fugue>=0.7.3\n- httpx>=0.24.1\n- intake>=0.6.0\n- jsonschema\n- lightgbm\n- maturin>=1.3,<1.4\n- mock\n- numpy>=1.22.4\n- pandas>=2\n- pre-commit\n- prompt_toolkit>=3.0.8\n- psycopg2\n- pyarrow>=14.0.1\n- pygments>=2.7.1\n- pyhive\n- pytest-cov\n- pytest-rerunfailures\n- pytest-xdist\n- pytest\n- python=3.9\n- py-xgboost==2.0.3\n- scikit-learn>=1.0.0\n- sphinx\n- sqlalchemy\n- tpot>=0.12.0\n# FIXME: https://github.com/fugue-project/fugue/issues/526\n- triad<0.9.2\n- tzlocal>=2.1\n- uvicorn>=0.14\n# GPU-specific requirements\n- cudatoolkit=11.8\n- cudf=24.06\n- cuml=24.06\n- dask-cudf=24.06\n- dask-cuda=24.06\n- ucx-proc=*=gpu\n- ucx-py=0.38\n- xgboost=*=rapidsai_py*\n- libxgboost=*=rapidsai_h*\n"
  },
  {
    "path": "continuous_integration/recipe/build.sh",
    "content": "#!/bin/bash\n\nset -ex\n\n# See https://github.com/conda-forge/rust-feedstock/blob/master/recipe/build.sh for cc env explanation\nif [ \"$c_compiler\" = gcc ] ; then\n    case \"$target_platform\" in\n        linux-64) rust_env_arch=X86_64_UNKNOWN_LINUX_GNU ;;\n        linux-aarch64) rust_env_arch=AARCH64_UNKNOWN_LINUX_GNU ;;\n        linux-ppc64le) rust_env_arch=POWERPC64LE_UNKNOWN_LINUX_GNU ;;\n        *) echo \"unknown target_platform $target_platform\" ; exit 1 ;;\n    esac\n\n    export CARGO_TARGET_${rust_env_arch}_LINKER=$CC\nfi\n\ndeclare -a _xtra_maturin_args\n\nmkdir -p $SRC_DIR/.cargo\n\nif [ \"$target_platform\" = \"osx-64\" ] ; then\n    cat <<EOF >> $SRC_DIR/.cargo/config\n[target.x86_64-apple-darwin]\nlinker = \"$CC\"\nrustflags = [\n  \"-C\", \"link-arg=-undefined\",\n  \"-C\", \"link-arg=dynamic_lookup\",\n]\n\nEOF\n\n    _xtra_maturin_args+=(--target=x86_64-apple-darwin)\n\nelif [ \"$target_platform\" = \"osx-arm64\" ] ; then\n    cat <<EOF >> $SRC_DIR/.cargo/config\n# Required for intermediate codegen stuff\n[target.x86_64-apple-darwin]\nlinker = \"$CC_FOR_BUILD\"\n\n# Required for final binary artifacts for target\n[target.aarch64-apple-darwin]\nlinker = \"$CC\"\nrustflags = [\n  \"-C\", \"link-arg=-undefined\",\n  \"-C\", \"link-arg=dynamic_lookup\",\n]\n\nEOF\n    _xtra_maturin_args+=(--target=aarch64-apple-darwin)\n\n    # This variable must be set to the directory containing the target's libpython DSO\n    export PYO3_CROSS_LIB_DIR=$PREFIX/lib\n\n    # xref: https://github.com/PyO3/pyo3/commit/7beb2720\n    export PYO3_PYTHON_VERSION=${PY_VER}\n\n    # xref: https://github.com/conda-forge/python-feedstock/issues/621\n    sed -i.bak 's,aarch64,arm64,g' $BUILD_PREFIX/venv/lib/os-patch.py\n    sed -i.bak 's,aarch64,arm64,g' $BUILD_PREFIX/venv/lib/platform-patch.py\nfi\n\nmaturin build -vv -j \"${CPU_COUNT}\" --release --strip --manylinux off --interpreter=\"${PYTHON}\" \"${_xtra_maturin_args[@]}\"\n\n\"${PYTHON}\" -m pip install 
$SRC_DIR/target/wheels/dask_sql*.whl --no-deps -vv\n"
  },
  {
    "path": "continuous_integration/recipe/conda_build_config.yaml",
    "content": "c_compiler:\n    - gcc\nc_compiler_version:\n    - '12'\nrust_compiler:\n    - rust\nrust_compiler_version:\n    - '1.72'\nmaturin:\n    - '1.3'\nxz:        # [linux64]\n    - '5'  # [linux64]\n"
  },
  {
    "path": "continuous_integration/recipe/meta.yaml",
    "content": "{% set name = \"dask-sql\" %}\n{% set major_minor_patch = environ.get('GIT_DESCRIBE_TAG', '0.0.0.dev').split('.') %}\n{% set new_patch = major_minor_patch[2] | int + 1 %}\n{% set version = (major_minor_patch[:2] + [new_patch]) | join('.') + environ.get('VERSION_SUFFIX', '') %}\n\n\npackage:\n  name: {{ name|lower }}\n  version: {{ version }}\n\nsource:\n  git_url: ../..\n\nbuild:\n  number: {{ GIT_DESCRIBE_NUMBER }}\n  entry_points:\n    - dask-sql-server = dask_sql.server.app:main\n    - dask-sql = dask_sql.cmd:main\n  string: py{{ python | replace(\".\", \"\") }}_{{ GIT_DESCRIBE_HASH }}_{{ GIT_DESCRIBE_NUMBER }}\n\nrequirements:\n  build:\n    - python                                 # [build_platform != target_platform]\n    - cross-python_{{ target_platform }}     # [build_platform != target_platform]\n    - maturin                                # [build_platform != target_platform]\n    - {{ compiler('c') }}\n    - {{ compiler('rust') }}\n  host:\n    - pip\n    - python\n    - maturin\n    - xz  # [linux64]\n  run:\n    - python\n    - dask >=2024.4.1\n    - pandas >=1.4.0\n    - fastapi >=0.92.0\n    - httpx >=0.24.1\n    - uvicorn >=0.14\n    - tzlocal >=2.1\n    - prompt-toolkit >=3.0.8\n    - pygments >=2.7.1\n    - tabulate\n\ntest:\n  imports:\n    - dask_sql\n  commands:\n    - pip check\n    - dask-sql-server --help\n    - dask-sql --help\n  requires:\n    - pip\n\nabout:\n  home: https://github.com/dask-contrib/dask-sql/\n  summary: SQL query layer for Dask\n  license: MIT\n  license_file: LICENSE.txt\n"
  },
  {
    "path": "continuous_integration/recipe/run_test.py",
    "content": "import dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql import Context\n\nc = Context()\n\ndata = \"\"\"\nname,x\nAlice,34\nBob,\n\"\"\"\n\ndf = pd.DataFrame({\"name\": [\"Alice\", \"Bob\", \"Chris\"] * 100, \"x\": list(range(300))})\nddf = dd.from_pandas(df, npartitions=10)\n\n# This is temporarily disabled since the query requires features that are not yet implemented\n# c.create_table(\"my_data\", ddf)\n\n# got = c.sql(\n#     \"\"\"\n#     SELECT\n#         my_data.name,\n#         SUM(my_data.x) AS \"S\"\n#     FROM\n#         my_data\n#     GROUP BY\n#         my_data.name\n# \"\"\"\n# )\n# expect = pd.DataFrame({\"name\": [\"Alice\", \"Bob\", \"Chris\"], \"S\": [14850, 14950, 15050]})\n\n# dd.assert_eq(got, expect)\n"
  },
  {
    "path": "continuous_integration/scripts/startup_script.py",
    "content": "from dask_sql.server.app import main\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "continuous_integration/scripts/update-dependencies.sh",
    "content": "#!/bin/bash\n\nUPDATE_ALL_CARGO_DEPS=\"${UPDATE_ALL_CARGO_DEPS:-true}\"\n# Update datafusion dependencies in the dask-planner to the latest revision from the default branch\nsed -i -r 's/^datafusion-([a-z]+).*/datafusion-\\1 = { git = \"https:\\/\\/github.com\\/apache\\/arrow-datafusion-python\\/\" }/g' Cargo.toml\n\nif [ \"$UPDATE_ALL_CARGO_DEPS\" = true ] ; then\n    cargo update\nfi\n"
  },
  {
    "path": "dask_sql/__init__.py",
    "content": "# FIXME: can we modify TLS model of Rust object to avoid aarch64 glibc bug?\n# https://github.com/dask-contrib/dask-sql/issues/1169\nfrom . import _datafusion_lib  # isort:skip\n\nimport importlib.metadata\n\nfrom dask.config import set\n\nfrom . import config\nfrom .cmd import cmd_loop\nfrom .context import Context\nfrom .datacontainer import Statistics\nfrom .server.app import run_server\n\n# TODO: get pyarrow strings and p2p shuffle working\nset(dataframe__convert_string=False, dataframe__shuffle__method=\"tasks\")\n\n__version__ = importlib.metadata.version(__name__)\n\n# __all__ must list names as strings, not the objects themselves\n__all__ = [\"__version__\", \"cmd_loop\", \"Context\", \"run_server\", \"Statistics\"]\n"
  },
  {
    "path": "dask_sql/_compat.py",
    "content": "import prompt_toolkit\nfrom packaging.version import parse as parseVersion\n\n_prompt_toolkit_version = parseVersion(prompt_toolkit.__version__)\n\n# TODO: remove if prompt-toolkit min version gets bumped\nPIPE_INPUT_CONTEXT_MANAGER = _prompt_toolkit_version >= parseVersion(\"3.0.29\")\n"
  },
  {
    "path": "dask_sql/cmd.py",
    "content": "import logging\nimport os\nimport sys\nimport tempfile\nimport traceback\nfrom argparse import ArgumentParser\nfrom functools import partial\nfrom typing import Union\n\nimport pandas as pd\nfrom dask.datasets import timeseries\nfrom dask.distributed import Client, as_completed\nfrom prompt_toolkit.auto_suggest import AutoSuggestFromHistory\nfrom prompt_toolkit.completion import WordCompleter\nfrom prompt_toolkit.history import FileHistory\nfrom prompt_toolkit.shortcuts import ProgressBar\nfrom pygments.lexers.sql import SqlLexer\n\ntry:\n    # prompt_toolkit version >= 2\n    from prompt_toolkit.lexers import PygmentsLexer\nexcept ImportError:  # pragma: no cover\n    # prompt_toolkit version < 2\n    from prompt_toolkit.layout.lexers import PygmentsLexer\n\nfrom dask_sql.context import Context\n\nmeta_command_completer = WordCompleter(\n    [\"\\\\l\", \"\\\\d?\", \"\\\\dt\", \"\\\\df\", \"\\\\de\", \"\\\\dm\", \"\\\\conninfo\", \"quit\"]\n)\n\n\nclass CompatiblePromptSession:\n    \"\"\"\n    Session object wrapper for the prompt_toolkit module\n\n    In the version jump from 1 to 2, the prompt_toolkit\n    introduced a PromptSession object.\n    Some environments however (e.g. 
Google Colab)\n    still rely on an older prompt_toolkit version,\n    so we try to support both versions\n    with this wrapper object.\n    All it does is export a `prompt` function.\n    \"\"\"\n\n    def __init__(self, lexer) -> None:  # pragma: no cover\n        # make sure dask-sql uses the same history file every time\n        kwargs = {\n            \"lexer\": lexer,\n            \"history\": FileHistory(\n                os.path.join(tempfile.gettempdir(), \"dask-sql-history\")\n            ),\n            \"auto_suggest\": AutoSuggestFromHistory(),\n            \"completer\": meta_command_completer,\n        }\n        try:\n            # Version >= 2.0.1: we can use the session object\n            from prompt_toolkit import PromptSession\n\n            session = PromptSession(**kwargs)\n            self.prompt = session.prompt\n        except ImportError:\n            # Version < 2.0: there is no session object\n            from prompt_toolkit.shortcuts import prompt\n\n            self.prompt = partial(prompt, **kwargs)\n\n\ndef _display_markdown(content, **kwargs):\n    df = pd.DataFrame(content, **kwargs)\n    print(df.to_markdown(tablefmt=\"fancy_grid\"))\n\n\ndef _parse_meta_command(sql):\n    command, _, arg = sql.partition(\" \")\n    return command, arg.strip()\n\n\ndef _meta_commands(sql: str, context: Context, client: Client) -> Union[bool, Client]:\n    \"\"\"\n    Parse meta commands and print their result.\n    Return True if a meta command was detected\n    (or the new Client when switching the Dask cluster).\n    \"\"\"\n    cmd, schema_name = _parse_meta_command(sql)\n    available_commands = [\n        [\"\\\\l\", \"List schemas\"],\n        [\"\\\\d?, help, ?\", \"Show available commands\"],\n        [\"\\\\conninfo\", \"Show Dask cluster info\"],\n        [\"\\\\dt [schema]\", \"List tables\"],\n        [\"\\\\df [schema]\", \"List functions\"],\n        [\"\\\\dm [schema]\", \"List models\"],\n        [\"\\\\de [schema]\", \"List experiments\"],\n        [\"\\\\dss [schema]\", \"Switch schema\"],\n 
       [\"\\\\dsc [dask scheduler address]\", \"Switch Dask cluster\"],\n        [\"quit\", \"Quits dask-sql-cli\"],\n    ]\n    if cmd == \"\\\\dsc\":\n        # Switch Dask cluster\n        _, scheduler_address = _parse_meta_command(sql)\n        client = Client(scheduler_address)\n        return client  # pragma: no cover\n    schema_name = schema_name or context.schema_name\n    if cmd == \"\\\\d?\" or cmd == \"help\" or cmd == \"?\":\n        _display_markdown(available_commands, columns=[\"Commands\", \"Description\"])\n    elif cmd == \"\\\\l\":\n        _display_markdown(context.schema.keys(), columns=[\"Schemas\"])\n    elif cmd == \"\\\\dt\":\n        _display_markdown(context.schema[schema_name].tables.keys(), columns=[\"Tables\"])\n    elif cmd == \"\\\\df\":\n        _display_markdown(\n            context.schema[schema_name].functions.keys(), columns=[\"Functions\"]\n        )\n    elif cmd == \"\\\\de\":\n        _display_markdown(\n            context.schema[schema_name].experiments.keys(), columns=[\"Experiments\"]\n        )\n    elif cmd == \"\\\\dm\":\n        _display_markdown(context.schema[schema_name].models.keys(), columns=[\"Models\"])\n    elif cmd == \"\\\\conninfo\":\n        cluster_info = [\n            [\"Dask scheduler\", client.scheduler.__dict__[\"addr\"]],\n            [\"Dask dashboard\", client.dashboard_link],\n            [\"Cluster status\", client.status],\n            [\"Dask workers\", len(client.cluster.workers)],\n        ]\n        _display_markdown(\n            cluster_info, columns=[\"components\", \"value\"]\n        )  # pragma: no cover\n    elif cmd == \"\\\\dss\":\n        if schema_name in context.schema:\n            context.schema_name = schema_name\n        else:\n            print(f\"Schema {schema_name} not available\")\n    elif cmd == \"quit\":\n        print(\"Quitting dask-sql ...\")\n        client.close()  # for safer side\n        sys.exit()\n    elif cmd.startswith(\"\\\\\"):\n        print(\n     
       f\"The meta command {cmd} not available, please use commands from below list\"\n        )\n        _display_markdown(available_commands, columns=[\"Commands\", \"Description\"])\n    else:\n        # nothing detected probably not a meta command\n        return False\n    return True\n\n\ndef cmd_loop(\n    context: Context = None,\n    client: Client = None,\n    startup=False,\n    log_level=None,\n):  # pragma: no cover\n    \"\"\"\n    Run a REPL for answering SQL queries using ``dask-sql``.\n    Every SQL expression that ``dask-sql`` understands can be used here.\n\n    Args:\n        context (:obj:`dask_sql.Context`): If set, use this context instead of an empty one.\n        client (:obj:`dask.distributed.Client`): If set, use this dask client instead of a new one.\n        startup (:obj:`bool`): Whether to wait until Apache Calcite was loaded\n        log_level: (:obj:`str`): The log level of the server and dask-sql\n\n    Example:\n        It is possible to run a REPL by using the CLI script in ``dask-sql``\n        or by calling this function directly in your user code:\n\n        .. 
code-block:: python\n\n            from dask_sql import cmd_loop\n\n            # Create your pre-filled context\n            c = Context()\n            ...\n\n            cmd_loop(context=c)\n\n        Of course, it is also possible to call the usual ``CREATE TABLE``\n        commands.\n    \"\"\"\n    pd.set_option(\"display.max_rows\", None)\n    pd.set_option(\"display.max_columns\", None)\n    pd.set_option(\"display.width\", None)\n    pd.set_option(\"display.max_colwidth\", None)\n\n    logging.basicConfig(level=log_level)\n\n    client = client or Client()\n    context = context or Context()\n\n    if startup:\n        context.sql(\"SELECT 1 + 1\").compute()\n\n    session = CompatiblePromptSession(lexer=PygmentsLexer(SqlLexer))\n\n    while True:\n        try:\n            text = session.prompt(\"(dask-sql) > \")\n        except KeyboardInterrupt:\n            continue\n        except EOFError:\n            break\n\n        text = text.rstrip(\";\").strip()\n\n        if not text:\n            continue\n\n        meta_command_detected = _meta_commands(text, context=context, client=client)\n        if isinstance(meta_command_detected, Client):\n            client = meta_command_detected\n\n        if not meta_command_detected:\n            try:\n                df = context.sql(text, return_futures=True)\n                if df is not None:  # some sql commands returns None\n                    df = df.persist()\n                    # Now turn it into a list of futures\n                    futures = client.futures_of(df)\n                    with ProgressBar() as pb:\n                        for _ in pb(\n                            as_completed(futures), total=len(futures), label=\"Executing\"\n                        ):\n                            continue\n                        df = df.compute()\n                        print(df.to_markdown(tablefmt=\"fancy_grid\"))\n\n            except Exception:\n                traceback.print_exc()\n\n\ndef 
main():  # pragma: no cover\n    parser = ArgumentParser()\n    parser.add_argument(\n        \"--scheduler-address\",\n        default=None,\n        help=\"Connect to this dask scheduler if given\",\n    )\n    parser.add_argument(\n        \"--log-level\",\n        default=None,\n        help=\"Set the log level of the server. Defaults to info.\",\n        choices=[\"DEBUG\", \"INFO\", \"WARNING\", \"ERROR\"],\n    )\n    parser.add_argument(\n        \"--load-test-data\",\n        default=False,\n        action=\"store_true\",\n        help=\"Preload some test data.\",\n    )\n    parser.add_argument(\n        \"--startup\",\n        default=False,\n        action=\"store_true\",\n        help=\"Wait until Apache Calcite was properly loaded\",\n    )\n\n    args = parser.parse_args()\n\n    client = None\n    if args.scheduler_address:\n        client = Client(args.scheduler_address)\n\n    context = Context()\n    if args.load_test_data:\n        df = timeseries(freq=\"1d\").reset_index(drop=False)\n        context.create_table(\"timeseries\", df.persist())\n\n    cmd_loop(\n        context=context, client=client, startup=args.startup, log_level=args.log_level\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "dask_sql/config.py",
    "content": "import os\n\nimport dask\nimport yaml\n\nfn = os.path.join(os.path.dirname(__file__), \"sql.yaml\")\n\nwith open(fn) as f:\n    defaults = yaml.safe_load(f)\n\ndask.config.update_defaults(defaults)\ndask.config.ensure_file(source=fn, comment=True)\n"
  },
  {
    "path": "dask_sql/context.py",
    "content": "import asyncio\nimport inspect\nimport logging\nfrom collections import Counter\nfrom typing import Any, Callable, Union\n\nimport dask.dataframe as dd\nimport pandas as pd\nfrom dask import config as dask_config\nfrom dask.base import optimize\nfrom dask.utils_test import hlg_layer\n\nfrom dask_sql._datafusion_lib import (\n    DaskSchema,\n    DaskSQLContext,\n    DaskSQLOptimizerConfig,\n    DaskTable,\n    DFOptimizationException,\n    DFParsingException,\n    LogicalPlan,\n)\n\ntry:\n    from dask_sql.physical.utils.statistics import parquet_statistics\nexcept ModuleNotFoundError:\n    parquet_statistics = None\n\ntry:\n    import dask_cuda  # noqa: F401\nexcept ImportError:  # pragma: no cover\n    pass\n\nfrom dask_sql import input_utils\nfrom dask_sql.datacontainer import (\n    UDF,\n    DataContainer,\n    FunctionDescription,\n    SchemaContainer,\n    Statistics,\n)\nfrom dask_sql.input_utils import InputType, InputUtil\nfrom dask_sql.integrations.ipython import ipython_integration\nfrom dask_sql.mappings import python_to_sql_type\nfrom dask_sql.physical.rel import RelConverter, custom, logical\nfrom dask_sql.physical.rex import RexConverter, core\nfrom dask_sql.utils import ParsingException\n\nlogger = logging.getLogger(__name__)\n\n\nclass Context:\n    \"\"\"\n    Main object to communicate with ``dask_sql``.\n    It holds a store of all registered data frames (= tables)\n    and can convert SQL queries to dask data frames.\n    The tables in these queries are referenced by the name,\n    which is given when registering a dask dataframe.\n\n    Example:\n        .. code-block:: python\n\n            from dask_sql import Context\n            c = Context()\n\n            # Register a table\n            c.create_table(\"my_table\", df)\n\n            # Now execute an SQL query. 
The result is a dask dataframe\n            result = c.sql(\"SELECT a, b FROM my_table\")\n\n            # Trigger the computation (or use the data frame for something else)\n            result.compute()\n\n    Usually, you will only ever have a single context in your program.\n\n    See also:\n        :func:`sql`\n        :func:`create_table`\n\n    \"\"\"\n\n    DEFAULT_CATALOG_NAME = \"dask_sql\"\n    DEFAULT_SCHEMA_NAME = \"root\"\n\n    def __init__(self, logging_level=logging.INFO):\n        \"\"\"\n        Create a new context.\n        \"\"\"\n\n        # Set the logging level for this SQL context\n        logging.basicConfig(level=logging_level)\n\n        # Name of the root catalog\n        self.catalog_name = self.DEFAULT_CATALOG_NAME\n        # Name of the root schema\n        self.schema_name = self.DEFAULT_SCHEMA_NAME\n        # All schema information\n        self.schema = {self.schema_name: SchemaContainer(self.schema_name)}\n        # A started SQL server (useful for jupyter notebooks)\n        self.sql_server = None\n\n        # Create the `DaskSQLOptimizerConfig` Rust context\n        optimizer_config = DaskSQLOptimizerConfig(\n            dask_config.get(\"sql.dynamic_partition_pruning\"),\n            dask_config.get(\"sql.fact_dimension_ratio\"),\n            dask_config.get(\"sql.max_fact_tables\"),\n            dask_config.get(\"sql.preserve_user_order\"),\n            dask_config.get(\"sql.filter_selectivity\"),\n        )\n\n        # Create the `DaskSQLContext` Rust context\n        self.context = DaskSQLContext(\n            self.catalog_name, self.schema_name, optimizer_config\n        )\n        self.context.register_schema(self.schema_name, DaskSchema(self.schema_name))\n\n        # # Register any default plugins, if nothing was registered before.\n        RelConverter.add_plugin_class(logical.DaskAggregatePlugin, replace=False)\n        RelConverter.add_plugin_class(logical.DaskCrossJoinPlugin, replace=False)\n        
RelConverter.add_plugin_class(logical.DaskEmptyRelationPlugin, replace=False)\n        RelConverter.add_plugin_class(logical.DaskFilterPlugin, replace=False)\n        RelConverter.add_plugin_class(logical.DaskJoinPlugin, replace=False)\n        RelConverter.add_plugin_class(logical.DaskLimitPlugin, replace=False)\n        RelConverter.add_plugin_class(logical.DaskProjectPlugin, replace=False)\n        RelConverter.add_plugin_class(logical.DaskSortPlugin, replace=False)\n        RelConverter.add_plugin_class(logical.DaskTableScanPlugin, replace=False)\n        RelConverter.add_plugin_class(logical.DaskUnionPlugin, replace=False)\n        RelConverter.add_plugin_class(logical.DaskValuesPlugin, replace=False)\n        RelConverter.add_plugin_class(logical.DaskWindowPlugin, replace=False)\n        RelConverter.add_plugin_class(logical.SamplePlugin, replace=False)\n        RelConverter.add_plugin_class(logical.ExplainPlugin, replace=False)\n        RelConverter.add_plugin_class(logical.SubqueryAlias, replace=False)\n        RelConverter.add_plugin_class(custom.AnalyzeTablePlugin, replace=False)\n        RelConverter.add_plugin_class(custom.CreateExperimentPlugin, replace=False)\n        RelConverter.add_plugin_class(custom.CreateModelPlugin, replace=False)\n        RelConverter.add_plugin_class(custom.CreateCatalogSchemaPlugin, replace=False)\n        RelConverter.add_plugin_class(custom.CreateMemoryTablePlugin, replace=False)\n        RelConverter.add_plugin_class(custom.CreateTablePlugin, replace=False)\n        RelConverter.add_plugin_class(custom.DropModelPlugin, replace=False)\n        RelConverter.add_plugin_class(custom.DropSchemaPlugin, replace=False)\n        RelConverter.add_plugin_class(custom.DropTablePlugin, replace=False)\n        RelConverter.add_plugin_class(custom.ExportModelPlugin, replace=False)\n        RelConverter.add_plugin_class(custom.PredictModelPlugin, replace=False)\n        RelConverter.add_plugin_class(custom.ShowColumnsPlugin, 
replace=False)\n        RelConverter.add_plugin_class(custom.DescribeModelPlugin, replace=False)\n        RelConverter.add_plugin_class(custom.ShowModelsPlugin, replace=False)\n        RelConverter.add_plugin_class(custom.ShowSchemasPlugin, replace=False)\n        RelConverter.add_plugin_class(custom.ShowTablesPlugin, replace=False)\n        RelConverter.add_plugin_class(custom.UseSchemaPlugin, replace=False)\n        RelConverter.add_plugin_class(custom.AlterSchemaPlugin, replace=False)\n        RelConverter.add_plugin_class(custom.AlterTablePlugin, replace=False)\n        RelConverter.add_plugin_class(custom.DistributeByPlugin, replace=False)\n\n        RexConverter.add_plugin_class(core.RexAliasPlugin, replace=False)\n        RexConverter.add_plugin_class(core.RexCallPlugin, replace=False)\n        RexConverter.add_plugin_class(core.RexInputRefPlugin, replace=False)\n        RexConverter.add_plugin_class(core.RexLiteralPlugin, replace=False)\n        RexConverter.add_plugin_class(core.RexScalarSubqueryPlugin, replace=False)\n\n        InputUtil.add_plugin_class(input_utils.DaskInputPlugin, replace=False)\n        InputUtil.add_plugin_class(input_utils.PandasLikeInputPlugin, replace=False)\n        InputUtil.add_plugin_class(input_utils.HiveInputPlugin, replace=False)\n        InputUtil.add_plugin_class(input_utils.IntakeCatalogInputPlugin, replace=False)\n        InputUtil.add_plugin_class(input_utils.SqlalchemyHiveInputPlugin, replace=False)\n        # needs to be the last entry, as it only checks for string\n        InputUtil.add_plugin_class(input_utils.LocationInputPlugin, replace=False)\n\n    def create_table(\n        self,\n        table_name: str,\n        input_table: InputType,\n        format: str = None,\n        persist: bool = False,\n        schema_name: str = None,\n        statistics: Statistics = None,\n        gpu: bool = False,\n        **kwargs,\n    ):\n        \"\"\"\n        Registering a (dask/pandas) table makes it usable in SQL 
queries.\n        The name you give here can be used as table name in the SQL later.\n\n        Please note, that the table is stored as it is now.\n        If you change the table later, you need to re-register.\n\n        Instead of passing an already loaded table, it is also possible\n        to pass a string to a storage location.\n        The library will then try to load the data using one of\n        `dask's read methods <https://docs.dask.org/en/latest/dataframe-create.html>`_.\n        If the file format can not be deduced automatically, it is also\n        possible to specify it via the ``format`` parameter.\n        Typical file formats are csv or parquet.\n        Any additional parameters will get passed on to the read method.\n        Please note that some file formats require additional libraries.\n        By default, the data will be lazily loaded. If you would like to\n        load the data directly into memory you can do so by setting\n        persist=True.\n\n        See :ref:`data_input` for more information.\n\n        Example:\n            This code registers a data frame as table \"data\"\n            and then uses it in a query.\n\n            .. code-block:: python\n\n                c.create_table(\"data\", df)\n                df_result = c.sql(\"SELECT a, b FROM data\")\n\n            This code reads a file from disk.\n            Please note that we assume that the file(s) are reachable under this path\n            from every node in the cluster\n\n            .. code-block:: python\n\n                c.create_table(\"data\", \"/home/user/data.csv\")\n                df_result = c.sql(\"SELECT a, b FROM data\")\n\n            This example reads from a hive table.\n\n            .. 
code-block:: python\n\n                from pyhive.hive import connect\n\n                cursor = connect(\"localhost\", 10000).cursor()\n                c.create_table(\"data\", cursor, hive_table_name=\"the_name_in_hive\")\n                df_result = c.sql(\"SELECT a, b FROM data\")\n\n        Args:\n            table_name: (:obj:`str`): Under which name should the new table be addressable\n            input_table (:class:`dask.dataframe.DataFrame` or :class:`pandas.DataFrame` or :obj:`str` or :class:`hive.Cursor`):\n                The data frame/location/hive connection to register.\n            format (:obj:`str`): Only used when passing a string into the ``input`` parameter.\n                Specify the file format directly here if it can not be deduced from the extension.\n                If set to \"memory\", load the data from a published dataset in the dask cluster.\n            persist (:obj:`bool`): Only used when passing a string into the ``input`` parameter.\n                Set to true to turn on loading the file data directly into memory.\n            schema_name: (:obj:`str`): in which schema to create the table. By default, will use the currently selected schema.\n            statistics: (:obj:`Statistics`): if given, use these statistics during the cost-based optimization.\n            gpu: (:obj:`bool`): if set to true, use dask-cudf to run the data frame calculations on your GPU.\n                Please note that the GPU support is currently not covering all of dask-sql's SQL language.\n            **kwargs: Additional arguments for specific formats. 
See :ref:`data_input` for more information.\n\n        \"\"\"\n        logger.debug(\n            f\"Creating table: '{table_name}' of format type '{format}' in schema '{schema_name}'\"\n        )\n\n        schema_name = schema_name or self.schema_name\n\n        dc = InputUtil.to_dc(\n            input_table,\n            table_name=table_name,\n            format=format,\n            persist=persist,\n            gpu=gpu,\n            **kwargs,\n        )\n\n        if type(input_table) == str:\n            dc.filepath = input_table\n            self.schema[schema_name].filepaths[table_name.lower()] = input_table\n        elif hasattr(input_table, \"dask\") and dd.utils.is_dataframe_like(input_table):\n            try:\n                if dd._dask_expr_enabled():\n                    from dask_expr.io.parquet import ReadParquet\n\n                    dask_filepath = None\n                    operations = input_table.find_operations(ReadParquet)\n                    for op in operations:\n                        dask_filepath = op._args[0]\n                else:\n                    dask_filepath = hlg_layer(\n                        input_table.dask, \"read-parquet\"\n                    ).creation_info[\"args\"][0]\n                dc.filepath = dask_filepath\n                self.schema[schema_name].filepaths[table_name.lower()] = dask_filepath\n            except KeyError:\n                logger.debug(\"Expected 'read-parquet' layer\")\n\n        if parquet_statistics and not dd._dask_expr_enabled() and not statistics:\n            statistics = parquet_statistics(dc.df)\n            if statistics:\n                row_count = 0\n                for d in statistics:\n                    row_count += d[\"num-rows\"]\n                statistics = Statistics(row_count)\n        if not statistics:\n            statistics = Statistics(float(\"nan\"))\n        dc.statistics = statistics\n\n        self.schema[schema_name].tables[table_name.lower()] = dc\n        
self.schema[schema_name].statistics[table_name.lower()] = statistics\n\n    def drop_table(self, table_name: str, schema_name: str = None):\n        \"\"\"\n        Remove a table with the given name from the registered tables.\n        This will also delete the dataframe.\n\n        Args:\n            table_name: (:obj:`str`): Which table to remove.\n\n        \"\"\"\n        schema_name = schema_name or self.schema_name\n        del self.schema[schema_name].tables[table_name]\n\n    def drop_schema(self, schema_name: str):\n        \"\"\"\n        Remove a schema with the given name from the registered schemas.\n        This will also delete all tables, functions etc.\n\n        Args:\n            schema_name: (:obj:`str`): Which schema to remove.\n\n        \"\"\"\n        if schema_name == self.DEFAULT_SCHEMA_NAME:\n            raise RuntimeError(f\"Default Schema `{schema_name}` cannot be deleted\")\n\n        del self.schema[schema_name]\n\n        if self.schema_name == schema_name:\n            self.schema_name = self.DEFAULT_SCHEMA_NAME\n\n    def register_function(\n        self,\n        f: Callable,\n        name: str,\n        parameters: list[tuple[str, type]],\n        return_type: type,\n        replace: bool = False,\n        schema_name: str = None,\n        row_udf: bool = False,\n    ):\n        \"\"\"\n        Register a custom function with the given name.\n        The function can be used (with this name)\n        in every SQL queries from now on - but only for scalar operations\n        (no aggregations).\n        This means, if you register a function \"f\", you can now call\n\n        .. code-block:: sql\n\n            SELECT f(x)\n            FROM df\n\n        Please keep in mind that you can only have one function with the same name,\n        regardless of whether it is an aggregation or a scalar function. 
By default,\n        attempting to register two functions with the same name will raise an error;\n        setting `replace=True` will give precedence to the most recently registered\n        function.\n\n        For the registration, you need to supply both the\n        list of parameter names and parameter types as well as the\n        return type. Use `numpy dtypes <https://numpy.org/doc/stable/reference/arrays.dtypes.html>`_ if possible.\n\n        More information: :ref:`custom`\n\n        Example:\n            This example registers a function \"f\", which\n            calculates the square of an integer and applies\n            it to the column ``x``.\n\n            .. code-block:: python\n\n                def f(x):\n                    return x ** 2\n\n                c.register_function(f, \"f\", [(\"x\", np.int64)], np.int64)\n\n                sql = \"SELECT f(x) FROM df\"\n                df_result = c.sql(sql)\n\n        Example of overwriting two functions with the same name:\n            This example registers a different function \"f\", which\n            calculates the floor division of an integer and applies\n            it to the column ``x``. It also shows how to overwrite\n            the previous function with the replace parameter.\n\n            .. code-block:: python\n\n                def f(x):\n                    return x // 2\n\n                c.register_function(f, \"f\", [(\"x\", np.int64)], np.int64, replace=True)\n\n                sql = \"SELECT f(x) FROM df\"\n                df_result = c.sql(sql)\n\n        Args:\n            f (:obj:`Callable`): The function to register\n            name (:obj:`str`): Under which name should the new function be addressable in SQL\n            parameters (:obj:`List[Tuple[str, type]]`): A list of tuples of parameter name and parameter type.\n                Use `numpy dtypes <https://numpy.org/doc/stable/reference/arrays.dtypes.html>`_ if possible. 
This\n                function is sensitive to the order of specified parameters when `row_udf=True`, and it is assumed\n                that column arguments are specified in order, followed by scalar arguments.\n            return_type (:obj:`type`): The return type of the function\n            replace (:obj:`bool`): If `True`, do not raise an error if a function with the same name is already\n                present; instead, replace the original function. Default is `False`.\n\n        See also:\n            :func:`register_aggregation`\n\n        \"\"\"\n        self._register_callable(\n            f,\n            name,\n            aggregation=False,\n            parameters=parameters,\n            return_type=return_type,\n            replace=replace,\n            schema_name=schema_name,\n            row_udf=row_udf,\n        )\n\n    def register_aggregation(\n        self,\n        f: dd.Aggregation,\n        name: str,\n        parameters: list[tuple[str, type]],\n        return_type: type,\n        replace: bool = False,\n        schema_name: str = None,\n    ):\n        \"\"\"\n        Register a custom aggregation with the given name.\n        The aggregation can be used (with this name)\n        in every SQL query from now on - but only for aggregation operations\n        (no scalar function calls).\n        This means, if you register an aggregation \"fagg\", you can now call\n\n        .. code-block:: sql\n\n            SELECT fagg(y)\n            FROM df\n            GROUP BY x\n\n        Please note that you can only have one function with the same name,\n        regardless of whether it is an aggregation or a scalar function.\n\n        For the registration, you need to supply both the\n        list of parameter names and parameter types as well as the\n        return type. 
Use `numpy dtypes <https://numpy.org/doc/stable/reference/arrays.dtypes.html>`_ if possible.\n\n        More information: :ref:`custom`\n\n        Example:\n            The following code registers a new aggregation \"fagg\", which\n            computes the sum of a column and uses it on the ``y`` column.\n\n            .. code-block:: python\n\n                fagg = dd.Aggregation(\"fagg\", lambda x: x.sum(), lambda x: x.sum())\n                c.register_aggregation(fagg, \"fagg\", [(\"x\", np.float64)], np.float64)\n\n                sql = \"SELECT fagg(y) FROM df GROUP BY x\"\n                df_result = c.sql(sql)\n\n        Args:\n            f (:class:`dask.dataframe.Aggregation`): The aggregate to register. See\n                `the dask documentation <https://docs.dask.org/en/latest/dataframe-groupby.html#aggregate>`_\n                for more information.\n            name (:obj:`str`): Under which name should the new aggregate be addressable in SQL\n            parameters (:obj:`List[Tuple[str, type]]`): A list of tuples of parameter name and parameter type.\n                Use `numpy dtypes <https://numpy.org/doc/stable/reference/arrays.dtypes.html>`_ if possible.\n            return_type (:obj:`type`): The return type of the function\n            replace (:obj:`bool`): Do not raise an error if the function is already present\n\n        See also:\n            :func:`register_function`\n\n        \"\"\"\n        self._register_callable(\n            f,\n            name,\n            aggregation=True,\n            parameters=parameters,\n            return_type=return_type,\n            replace=replace,\n            schema_name=schema_name,\n        )\n\n    def sql(\n        self,\n        sql: Any,\n        return_futures: bool = True,\n        dataframes: dict[str, Union[dd.DataFrame, pd.DataFrame]] = None,\n        gpu: bool = False,\n        config_options: dict[str, Any] = None,\n    ) -> Union[dd.DataFrame, pd.DataFrame]:\n        \"\"\"\n        
Query the registered tables with the given SQL.\n        The SQL follows approximately the PostgreSQL standard - however, not all\n        operations are implemented yet.\n        In general, only SELECT statements (no data manipulation) work.\n        For more information, see :ref:`sql`.\n\n        Example:\n            In this example, a query is called\n            using the registered tables and then\n            executed using dask.\n\n            .. code-block:: python\n\n                result = c.sql(\"SELECT a, b FROM my_table\")\n                print(result.compute())\n\n        Args:\n            sql (:obj:`str`): The query string to execute\n            return_futures (:obj:`bool`): Return the unexecuted dask dataframe or the data itself.\n                Defaults to returning the dask dataframe.\n            dataframes (:obj:`Dict[str, dask.dataframe.DataFrame]`): additional Dask or pandas dataframes\n                to register before executing this query\n            gpu (:obj:`bool`): Whether or not to load the additional Dask or pandas dataframes (if any) on GPU;\n                requires cuDF / dask-cuDF if enabled. 
Defaults to False.\n            config_options (:obj:`Dict[str,Any]`): Specific configuration options to pass during\n                query execution\n        Returns:\n            :obj:`dask.dataframe.DataFrame`: the created data frame of this query.\n        \"\"\"\n        with dask_config.set(config_options):\n            if dataframes is not None:\n                for df_name, df in dataframes.items():\n                    self.create_table(df_name, df, gpu=gpu)\n\n            if isinstance(sql, str):\n                rel, _ = self._get_ral(sql)\n            elif isinstance(sql, LogicalPlan):\n                rel = sql\n            else:\n                raise RuntimeError(\n                    f\"Encountered unsupported `LogicalPlan` sql type: {type(sql)}\"\n                )\n\n            return self._compute_table_from_rel(rel, return_futures)\n\n    def explain(\n        self,\n        sql: str,\n        dataframes: dict[str, Union[dd.DataFrame, pd.DataFrame]] = None,\n        gpu: bool = False,\n    ) -> str:\n        \"\"\"\n        Return the stringified relational algebra that this query will produce\n        once triggered (with ``sql()``).\n        Helpful to understand the inner workings of dask-sql, but typically not\n        needed to query your data.\n\n        If the query is of DDL type (e.g. CREATE TABLE or DESCRIBE SCHEMA),\n        no relational algebra plan is created and therefore nothing returned.\n\n        Args:\n            sql (:obj:`str`): The query string to use\n            dataframes (:obj:`Dict[str, dask.dataframe.DataFrame]`): additional Dask or pandas dataframes\n                to register before executing this query\n            gpu (:obj:`bool`): Whether or not to load the additional Dask or pandas dataframes (if any) on GPU;\n                requires cuDF / dask-cuDF if enabled. 
Defaults to False.\n\n        Returns:\n            :obj:`str`: a description of the created relational algebra.\n\n        \"\"\"\n        dynamic_partition_pruning = dask_config.get(\"sql.dynamic_partition_pruning\")\n        if not dask_config.get(\"sql.optimizer.verbose\"):\n            dask_config.set({\"sql.dynamic_partition_pruning\": False})\n\n        if dataframes is not None:\n            for df_name, df in dataframes.items():\n                self.create_table(df_name, df, gpu=gpu)\n\n        _, rel_string = self._get_ral(sql)\n        dask_config.set({\"sql.dynamic_partition_pruning\": dynamic_partition_pruning})\n        return rel_string\n\n    def visualize(self, sql: str, filename=\"mydask.png\") -> None:  # pragma: no cover\n        \"\"\"Visualize the computation of the given SQL into the png\"\"\"\n        result = self.sql(sql, return_futures=True)\n        (result,) = optimize(result)\n\n        result.visualize(filename)\n\n    def create_schema(self, schema_name: str):\n        \"\"\"\n        Create a new schema in the database.\n\n        Args:\n            schema_name (:obj:`str`): The name of the schema to create\n        \"\"\"\n        self.schema[schema_name] = SchemaContainer(schema_name)\n\n    def alter_schema(self, old_schema_name, new_schema_name):\n        \"\"\"\n        Alter schema\n\n        Args:\n             old_schema_name:\n             new_schema_name:\n        \"\"\"\n        self.schema[new_schema_name] = self.schema.pop(old_schema_name)\n\n    def alter_table(self, old_table_name, new_table_name, schema_name=None):\n        \"\"\"\n        Alter Table\n\n        Args:\n            old_table_name:\n            new_table_name:\n            schema_name:\n        \"\"\"\n        if schema_name is None:\n            schema_name = self.schema_name\n\n        self.schema[schema_name].tables[new_table_name] = self.schema[\n            schema_name\n        ].tables.pop(old_table_name)\n\n    def register_experiment(\n        
self,\n        experiment_name: str,\n        experiment_results: pd.DataFrame,\n        schema_name: str = None,\n    ):\n        schema_name = schema_name or self.schema_name\n        self.schema[schema_name].experiments[\n            experiment_name.lower()\n        ] = experiment_results\n\n    def register_model(\n        self,\n        model_name: str,\n        model: Any,\n        training_columns: list[str],\n        schema_name: str = None,\n    ):\n        \"\"\"\n        Add a model to the model registry.\n        A model can be anything which has a `.predict` function that transforms\n        a Dask dataframe into predicted labels (as a Dask series).\n        After model registration, the model can be used in calls to\n        `SELECT ... FROM PREDICT` with the given name.\n        Instead of creating your own model and registering it, you can also\n        train a model directly in dask-sql. See the SQL command `CREATE MODEL`.\n\n        Args:\n            model_name (:obj:`str`): The name of the model\n            model: The model to store\n            training_columns (list of str): The names of the columns which were\n                used during the training.\n        \"\"\"\n        schema_name = schema_name or self.schema_name\n        self.schema[schema_name].models[model_name.lower()] = (model, training_columns)\n\n    def ipython_magic(\n        self, auto_include=False, disable_highlighting=True\n    ):  # pragma: no cover\n        \"\"\"\n        Register a new ipython/jupyter magic function \"sql\"\n        which sends its input as a string to the :func:`sql` function.\n        After calling this magic function in a Jupyter notebook or\n        an IPython shell, you can write\n\n        .. code-block:: python\n\n            %sql SELECT * from data\n\n        or\n\n        .. code-block:: python\n\n            %%sql\n            SELECT * from data\n\n        instead of\n\n        .. 
code-block:: python\n\n            c.sql(\"SELECT * from data\")\n\n        Args:\n            auto_include (:obj:`bool`): If set to true, automatically\n                create a table for every pandas or Dask dataframe in the calling\n                context. That means, if you define a dataframe in your jupyter\n                notebook you can use it with the same name in your sql call.\n                Use this setting with care as any defined dataframe can\n                easily override tables created via `CREATE TABLE`.\n\n                .. code-block:: python\n\n                    df = ...\n\n                    # Later, without any calls to create_table\n\n                    %%sql\n                    SELECT * FROM df\n\n            disable_highlighting (:obj:`bool`): If set to true, automatically\n                disable syntax highlighting. If you are working in jupyter lab,\n                disable_highlighting must be set to true to enable ipython_magic\n                functionality. 
If you are working in a classic jupyter notebook,\n                you may set disable_highlighting=False if desired.\n        \"\"\"\n        ipython_integration(\n            self, auto_include=auto_include, disable_highlighting=disable_highlighting\n        )\n\n    def run_server(self, **kwargs):  # pragma: no cover\n        \"\"\"\n        Run a HTTP server for answering SQL queries using ``dask-sql``.\n\n        See :ref:`server` for more information.\n\n        Args:\n            client (:obj:`dask.distributed.Client`): If set, use this dask client instead of a new one.\n            host (:obj:`str`): The host interface to listen on (defaults to all interfaces)\n            port (:obj:`int`): The port to listen on (defaults to 8080)\n            log_level: (:obj:`str`): The log level of the server and dask-sql\n        \"\"\"\n        from dask_sql.server.app import run_server\n\n        self.stop_server()\n        self.server = run_server(**kwargs)\n\n    def stop_server(self):  # pragma: no cover\n        \"\"\"\n        Stop a SQL server started by ``run_server``.\n        \"\"\"\n        if self.sql_server is not None:\n            loop = asyncio.get_event_loop()\n            assert loop\n            loop.create_task(self.sql_server.shutdown())\n\n        self.sql_server = None\n\n    def fqn(self, tbl: \"DaskTable\") -> tuple[str, str]:\n        \"\"\"\n        Return the fully qualified name of an object, maybe including the schema name.\n\n        Args:\n            tbl (:obj:`DaskTable`): The Rust DaskTable instance of the view or table.\n\n        Returns:\n            :obj:`tuple` of :obj:`str`: The fully qualified name of the object\n        \"\"\"\n        schema_name, table_name = tbl.getSchema(), tbl.getTableName()\n\n        if schema_name is None or schema_name == \"\":\n            schema_name = self.schema_name\n\n        return schema_name, table_name\n\n    def _prepare_schemas(self):\n        \"\"\"\n        Create a list of schemas 
filled with the dataframes\n        and functions we have currently in our schema list\n        \"\"\"\n        logger.debug(\n            f\"There are {len(self.schema)} existing schema(s): {self.schema.keys()}\"\n        )\n        schema_list = []\n\n        for schema_name, schema in self.schema.items():\n            logger.debug(f\"Preparing Schema: '{schema_name}'\")\n            rust_schema = DaskSchema(schema_name)\n\n            if not schema.tables:\n                logger.warning(\"No tables are registered.\")\n\n            for name, dc in schema.tables.items():\n                row_count = (\n                    float(schema.statistics[name].row_count)\n                    if name in schema.statistics\n                    else float(0)\n                )\n\n                filepath = schema.filepaths[name] if name in schema.filepaths else None\n                df = dc.df\n                columns = df.columns\n                cc = dc.column_container\n                if not dask_config.get(\"sql.identifier.case_sensitive\"):\n                    columns = [col.lower() for col in columns]\n                    cc = cc.rename_handle_duplicates(df.columns, columns)\n                    dc.column_container = cc\n                column_type_mapping = list(\n                    zip(columns, map(python_to_sql_type, df.dtypes))\n                )\n                table = DaskTable(\n                    schema_name, name, row_count, column_type_mapping, filepath\n                )\n\n                rust_schema.add_table(table)\n\n            if not schema.functions:\n                logger.debug(\"No custom functions defined.\")\n            for function_description in schema.function_lists:\n                name = function_description.name\n                sql_return_type = function_description.return_type\n                sql_parameters = function_description.parameters\n                if function_description.aggregation:\n                    
logger.debug(f\"Adding function '{name}' to schema as aggregation.\")\n                    rust_schema.add_or_overload_function(\n                        name,\n                        [param[1].getDataType() for param in sql_parameters],\n                        sql_return_type.getDataType(),\n                        True,\n                    )\n                else:\n                    logger.debug(\n                        f\"Adding function '{name}' to schema as scalar function.\"\n                    )\n                    rust_schema.add_or_overload_function(\n                        name,\n                        [param[1].getDataType() for param in sql_parameters],\n                        sql_return_type.getDataType(),\n                        False,\n                    )\n\n            schema_list.append(rust_schema)\n\n        return schema_list\n\n    def _get_ral(self, sql):\n        \"\"\"Helper function to turn the sql query into a relational algebra and resulting column names\"\"\"\n\n        logger.debug(f\"Entering _get_ral('{sql}')\")\n\n        optimizer_config = DaskSQLOptimizerConfig(\n            dask_config.get(\"sql.dynamic_partition_pruning\"),\n            dask_config.get(\"sql.fact_dimension_ratio\"),\n            dask_config.get(\"sql.max_fact_tables\"),\n            dask_config.get(\"sql.preserve_user_order\"),\n            dask_config.get(\"sql.filter_selectivity\"),\n        )\n        self.context.set_optimizer_config(optimizer_config)\n\n        # get the schema of what we currently have registered\n        schemas = self._prepare_schemas()\n        for schema in schemas:\n            self.context.register_schema(schema.name, schema)\n        try:\n            sqlTree = self.context.parse_sql(sql)\n        except DFParsingException as pe:\n            raise ParsingException(sql, str(pe))\n        logger.debug(f\"_get_ral -> sqlTree: {sqlTree}\")\n\n        rel = sqlTree\n\n        # TODO: Need to understand if this list here is 
actually needed? For now just use the first entry.\n        if len(sqlTree) > 1:\n            raise RuntimeError(\n                f\"Multiple 'Statements' encountered for SQL {sql}. Please share this with the dev team!\"\n            )\n\n        try:\n            nonOptimizedRel = self.context.logical_relational_algebra(sqlTree[0])\n        except DFParsingException as pe:\n            raise ParsingException(sql, str(pe)) from None\n\n        # Optimize the `LogicalPlan` or skip if configured\n        if dask_config.get(\"sql.optimize\"):\n            try:\n                rel = self.context.run_preoptimizer(nonOptimizedRel)\n                rel = self.context.optimize_relational_algebra(rel)\n            except DFOptimizationException as oe:\n                # Use original plan and warn about inability to optimize plan\n                rel = nonOptimizedRel\n                logger.warning(str(oe))\n        else:\n            rel = nonOptimizedRel\n\n        rel_string = rel.explain_original()\n        logger.debug(f\"_get_ral -> LogicalPlan: {rel}\")\n        logger.debug(f\"Extracted relational algebra:\\n {rel_string}\")\n\n        return rel, rel_string\n\n    def _compute_table_from_rel(self, rel: \"LogicalPlan\", return_futures: bool = True):\n        dc = RelConverter.convert(rel, context=self)\n\n        if rel.get_current_node_type() == \"Explain\":\n            return dc\n        if dc is None:\n            return\n\n        # Optimization might remove some alias projects. Make sure to keep them here.\n        select_names = [field for field in rel.getRowType().getFieldList()]\n\n        if select_names:\n            cc = dc.column_container\n\n            select_names = select_names[: len(cc.columns)]\n\n            # Use FQ name if not unique and simple name if it is unique. 
If a join contains the same column\n            # names the output col is prepended with the fully qualified column name\n            field_counts = Counter([field.getName() for field in select_names])\n            select_names = [\n                field.getQualifiedName()\n                if field_counts[field.getName()] > 1\n                else field.getName()\n                for field in select_names\n            ]\n\n            cc = cc.rename(\n                {\n                    df_col: select_name\n                    for df_col, select_name in zip(cc.columns, select_names)\n                }\n            )\n            dc = DataContainer(dc.df, cc)\n\n        df = dc.assign()\n        if not return_futures:\n            df = df.compute()\n\n        return df\n\n    def _get_tables_from_stack(self):\n        \"\"\"Helper function to return all dask/pandas dataframes from the calling stack\"\"\"\n        stack = inspect.stack()\n\n        tables = {}\n\n        # Traverse the stacks from inside to outside\n        for frame_info in stack:\n            for var_name, variable in frame_info.frame.f_locals.items():\n                if var_name.startswith(\"_\"):\n                    continue\n                if not dd.utils.is_dataframe_like(variable):\n                    continue\n\n                # only set them if not defined in an inner context\n                tables[var_name] = tables.get(var_name, variable)\n\n        return tables\n\n    def _register_callable(\n        self,\n        f: Any,\n        name: str,\n        aggregation: bool,\n        parameters: list[tuple[str, type]],\n        return_type: type,\n        replace: bool = False,\n        schema_name=None,\n        row_udf: bool = False,\n    ):\n        \"\"\"Helper function to do the function or aggregation registration\"\"\"\n\n        schema_name = schema_name or self.schema_name\n        schema = self.schema[schema_name]\n\n        # validate and cache UDF metadata\n        
sql_parameters = [\n            (name, python_to_sql_type(param_type)) for name, param_type in parameters\n        ]\n        sql_return_type = python_to_sql_type(return_type)\n\n        if not aggregation:\n            f = UDF(f, row_udf, parameters, return_type)\n        lower_name = name.lower()\n        if lower_name in schema.functions:\n            if replace:\n                schema.function_lists = list(\n                    filter(\n                        lambda f: f.name.lower() != lower_name,\n                        schema.function_lists,\n                    )\n                )\n                del schema.functions[lower_name]\n\n            elif schema.functions[lower_name] != f:\n                raise ValueError(\n                    \"Registering multiple functions with the same name is only permitted if replace=True\"\n                )\n\n        schema.function_lists.append(\n            FunctionDescription(\n                name.upper(), sql_parameters, sql_return_type, aggregation\n            )\n        )\n        schema.function_lists.append(\n            FunctionDescription(\n                name.lower(), sql_parameters, sql_return_type, aggregation\n            )\n        )\n        schema.functions[lower_name] = f\n"
  },
  {
    "path": "dask_sql/datacontainer.py",
    "content": "from collections import namedtuple\nfrom typing import Any, Union\n\nimport dask.dataframe as dd\nimport pandas as pd\n\nColumnType = Union[str, int]\n\nFunctionDescription = namedtuple(\n    \"FunctionDescription\", [\"name\", \"parameters\", \"return_type\", \"aggregation\"]\n)\n\n\nclass ColumnContainer:\n    # Forward declaration\n    pass\n\n\nclass ColumnContainer:\n    \"\"\"\n    Helper class to store a list of columns,\n    which are not necessarily the ones of the dask dataframe.\n    Instead, the container also stores a mapping from \"frontend\"\n    columns (columns with the names and order expected by SQL)\n    to \"backend\" columns (the real column names used by dask)\n    to prevent unnecessary renames.\n    \"\"\"\n\n    def __init__(\n        self,\n        frontend_columns: list[str],\n        frontend_backend_mapping: Union[dict[str, ColumnType], None] = None,\n    ):\n        assert all(\n            isinstance(col, str) for col in frontend_columns\n        ), \"All frontend columns need to be of string type\"\n        self._frontend_columns = list(frontend_columns)\n        if frontend_backend_mapping is None:\n            self._frontend_backend_mapping = {\n                col: col for col in self._frontend_columns\n            }\n        else:\n            self._frontend_backend_mapping = frontend_backend_mapping\n\n    def _copy(self) -> ColumnContainer:\n        \"\"\"\n        Internal function to copy this container\n        \"\"\"\n        return ColumnContainer(\n            self._frontend_columns.copy(), self._frontend_backend_mapping.copy()\n        )\n\n    def limit_to(self, fields: list[str]) -> ColumnContainer:\n        \"\"\"\n        Create a new ColumnContainer, which has frontend columns\n        limited to only the ones given as parameter.\n        Also uses the order of these as the new column order.\n        \"\"\"\n        if not fields:\n            return self  # pragma: no cover\n\n        assert 
all(f in self._frontend_backend_mapping for f in fields)\n        cc = self._copy()\n        cc._frontend_columns = [str(x) for x in fields]\n        return cc\n\n    def rename(self, columns: dict[str, str]) -> ColumnContainer:\n        \"\"\"\n        Return a new ColumnContainer where the frontend columns\n        are renamed according to the given mapping.\n        Columns not present in the mapping are not touched,\n        the order is preserved.\n        \"\"\"\n        cc = self._copy()\n        for column_from, column_to in columns.items():\n            backend_column = self._frontend_backend_mapping[str(column_from)]\n            cc._frontend_backend_mapping[str(column_to)] = backend_column\n\n        cc._frontend_columns = [\n            str(columns[col]) if col in columns else col\n            for col in self._frontend_columns\n        ]\n\n        return cc\n\n    def rename_handle_duplicates(\n        self, from_columns: list[str], to_columns: list[str]\n    ) -> ColumnContainer:\n        \"\"\"\n        Same as `rename` but additionally handles presence of\n        duplicates in `from_columns`\n        \"\"\"\n        cc = self._copy()\n        cc._frontend_backend_mapping.update(\n            {\n                str(column_to): self._frontend_backend_mapping[str(column_from)]\n                for column_from, column_to in zip(from_columns, to_columns)\n            }\n        )\n\n        columns = dict(zip(from_columns, to_columns))\n        cc._frontend_columns = [\n            str(columns.get(col, col)) for col in self._frontend_columns\n        ]\n\n        return cc\n\n    def mapping(self) -> list[tuple[str, ColumnType]]:\n        \"\"\"\n        The mapping from frontend columns to backend columns.\n        \"\"\"\n        return list(self._frontend_backend_mapping.items())\n\n    @property\n    def columns(self) -> list[str]:\n        \"\"\"\n        The stored frontend columns in the correct order\n        \"\"\"\n        return 
self._frontend_columns.copy()\n\n    def add(\n        self, frontend_column: str, backend_column: Union[str, None] = None\n    ) -> ColumnContainer:\n        \"\"\"\n        Return a new ColumnContainer with the\n        given column added.\n        The column is added at the last position in the column list.\n        \"\"\"\n        cc = self._copy()\n\n        frontend_column = str(frontend_column)\n\n        cc._frontend_backend_mapping[frontend_column] = str(\n            backend_column or frontend_column\n        )\n        if frontend_column not in cc._frontend_columns:\n            cc._frontend_columns.append(frontend_column)\n\n        return cc\n\n    def get_backend_by_frontend_index(self, index: int) -> str:\n        \"\"\"\n        Get back the dask column, which is referenced by the\n        frontend (SQL) column with the given index.\n        \"\"\"\n        frontend_column = self._frontend_columns[index]\n        backend_column = self._frontend_backend_mapping[frontend_column]\n        return backend_column\n\n    def get_backend_by_frontend_name(self, column: str) -> str:\n        \"\"\"\n        Get back the dask column, which is referenced by the\n        frontend (SQL) column with the given name.\n        \"\"\"\n\n        try:\n            return self._frontend_backend_mapping[column]\n        except KeyError:\n            return column\n\n    def make_unique(self, prefix=\"col\"):\n        \"\"\"\n        Make sure we have unique column names by calling each column\n\n            <prefix>_<number>\n\n        where <number> is the column index.\n        \"\"\"\n        return self.rename(\n            columns={str(col): f\"{prefix}_{i}\" for i, col in enumerate(self.columns)}\n        )\n\n\nclass Statistics:\n    \"\"\"\n    Statistics are used during the cost-based optimization.\n    Currently, only the row count is supported, more\n    properties might follow. 
It needs to be provided by the user.\n    \"\"\"\n\n    def __init__(self, row_count: int) -> None:\n        self.row_count = row_count\n\n    def __eq__(self, other):\n        if isinstance(other, Statistics):\n            return self.row_count == other.row_count\n        return False\n\n\nclass DataContainer:\n    \"\"\"\n    In SQL, every column operation or reference is done via\n    the column index. Some dask operations, such as grouping,\n    joining or concatenating preserve the columns in a different\n    order than SQL would expect.\n    However, we do not want to change the column data itself\n    all the time (because this would lead to computational overhead),\n    but still would like to keep the columns accessible by name and index.\n    For this, we add an additional `ColumnContainer` to each dataframe,\n    which does all the column mapping between \"frontend\"\n    (what SQL expects, also in the correct order)\n    and \"backend\" (what dask has).\n    \"\"\"\n\n    def __init__(\n        self,\n        df: dd.DataFrame,\n        column_container: ColumnContainer,\n        statistics: Statistics = None,\n        filepath: str = None,\n    ):\n        self.df = df\n        self.column_container = column_container\n        self.statistics = statistics\n        self.filepath = filepath\n\n    def assign(self) -> dd.DataFrame:\n        \"\"\"\n        Combine the column mapping with the actual data and return\n        a dataframe which has the columns specified in the\n        stored ColumnContainer.\n        \"\"\"\n        df = self.df[\n            [\n                self.column_container._frontend_backend_mapping[out_col]\n                for out_col in self.column_container.columns\n            ]\n        ]\n        df.columns = self.column_container.columns\n\n        return df\n\n\nclass UDF:\n    def __init__(self, func, row_udf: bool, params, return_type=None):\n        \"\"\"\n        Helper class that handles different types of UDFs and 
manages\n        how they should be mapped to dask operations. Two versions of\n        UDFs are supported - when `row_udf=False`, the UDF is treated\n        as expecting series-like objects as arguments and will simply\n        run those through the function. When `row_udf=True` a row udf\n        is expected and should be written to expect a dictlike object\n        containing scalars\n        \"\"\"\n        self.row_udf = row_udf\n        self.func = func\n\n        self.names = [param[0] for param in params]\n\n        self.meta = (None, return_type)\n\n    def __call__(self, *args, **kwargs):\n        if self.row_udf:\n            column_args = []\n            scalar_args = []\n            for operand in args:\n                if isinstance(operand, dd.Series):\n                    column_args.append(operand)\n                else:\n                    scalar_args.append(operand)\n\n            df = column_args[0].to_frame(self.names[0])\n            for name, col in zip(self.names[1:], column_args[1:]):\n                df[name] = col\n            result = df.apply(\n                self.func, axis=1, args=tuple(scalar_args), meta=self.meta\n            ).astype(self.meta[1])\n        else:\n            result = self.func(*args, **kwargs)\n        return result\n\n    def __eq__(self, other):\n        if isinstance(other, UDF):\n            return self.func == other.func and self.row_udf == other.row_udf\n        return NotImplemented\n\n    def __hash__(self):\n        return (self.func, self.row_udf).__hash__()\n\n\nclass SchemaContainer:\n    def __init__(self, name: str):\n        self.__name__ = name\n        self.tables: dict[str, DataContainer] = {}\n        self.statistics: dict[str, Statistics] = {}\n        self.experiments: dict[str, pd.DataFrame] = {}\n        self.models: dict[str, tuple[Any, list[str]]] = {}\n        self.functions: dict[str, UDF] = {}\n        self.function_lists: list[FunctionDescription] = []\n        self.filepaths: 
dict[str, str] = {}\n"
  },
  {
    "path": "dask_sql/input_utils/__init__.py",
    "content": "from .convert import InputType, InputUtil\nfrom .dask import DaskInputPlugin\nfrom .hive import HiveInputPlugin\nfrom .intake import IntakeCatalogInputPlugin\nfrom .location import LocationInputPlugin\nfrom .pandaslike import PandasLikeInputPlugin\nfrom .sqlalchemy import SqlalchemyHiveInputPlugin\n\n# `__all__` must contain names as strings, not the objects themselves\n__all__ = [\n    \"InputUtil\",\n    \"InputType\",\n    \"DaskInputPlugin\",\n    \"HiveInputPlugin\",\n    \"IntakeCatalogInputPlugin\",\n    \"LocationInputPlugin\",\n    \"PandasLikeInputPlugin\",\n    \"SqlalchemyHiveInputPlugin\",\n]\n"
  },
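  {
    "path": "dask_sql/input_utils/base.py",
    "content": "from typing import Any\n\n\nclass BaseInputPlugin:\n    \"\"\"Base class for all input plugins; subclasses implement both methods\"\"\"\n\n    def is_correct_input(\n        self, input_item: Any, table_name: str, format: str = None, **kwargs\n    ):\n        \"\"\"Return whether this plugin can handle the given input item\"\"\"\n        raise NotImplementedError\n\n    def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs):\n        \"\"\"Load the given input item and return it as a dask dataframe\"\"\"\n        raise NotImplementedError\n"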
  },
  {
    "path": "dask_sql/input_utils/convert.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING, Union\n\nimport dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.input_utils.base import BaseInputPlugin\nfrom dask_sql.utils import Pluggable\n\nif TYPE_CHECKING:\n    import cudf\n    import hive\n    import sqlalchemy\n\nlogger = logging.getLogger(__name__)\n\nInputType = Union[\n    dd.DataFrame,\n    pd.DataFrame,\n    str,\n    Union[\n        \"sqlalchemy.engine.base.Connection\",\n        \"hive.Cursor\",\n        \"cudf.core.dataframe.DataFrame\",\n    ],\n]\n\n\nclass InputUtil(Pluggable):\n    \"\"\"\n    Plugin list and helper class for transforming the inputs to\n    create table into a dask dataframe\n    \"\"\"\n\n    @classmethod\n    def add_plugin_class(cls, plugin_class: type[BaseInputPlugin], replace=True):\n        \"\"\"Convenience function to add a class directly to the plugins\"\"\"\n        logger.debug(f\"Registering Input plugin for {plugin_class}\")\n        cls.add_plugin(str(plugin_class), plugin_class(), replace=replace)\n\n    @classmethod\n    def to_dc(\n        cls,\n        input_item: InputType,\n        table_name: str,\n        format: str = None,\n        persist: bool = True,\n        gpu: bool = False,\n        **kwargs,\n    ) -> DataContainer:\n        \"\"\"\n        Turn possible input descriptions or formats (e.g. 
dask dataframes, pandas dataframes,\n        locations as string, hive tables) into the loaded data containers,\n        maybe persist them to cluster memory before.\n        \"\"\"\n        filled_get_dask_dataframe = lambda *args: cls._get_dask_dataframe(\n            *args,\n            table_name=table_name,\n            format=format,\n            gpu=gpu,\n            **kwargs,\n        )\n\n        if isinstance(input_item, list):\n            table = dd.concat([filled_get_dask_dataframe(item) for item in input_item])\n        else:\n            table = filled_get_dask_dataframe(input_item)\n\n        if persist:\n            table = table.persist()\n\n        return DataContainer(table.copy(), ColumnContainer(table.columns))\n\n    @classmethod\n    def _get_dask_dataframe(\n        cls,\n        input_item: InputType,\n        table_name: str,\n        format: str = None,\n        gpu: bool = False,\n        **kwargs,\n    ):\n        plugin_list = cls.get_plugins()\n\n        for plugin in plugin_list:\n            if plugin.is_correct_input(\n                input_item, table_name=table_name, format=format, **kwargs\n            ):\n                return plugin.to_dc(\n                    input_item, table_name=table_name, format=format, gpu=gpu, **kwargs\n                )\n\n        raise ValueError(f\"Do not understand the input type {type(input_item)}\")\n"
  },
  {
    "path": "dask_sql/input_utils/dask.py",
    "content": "from typing import Any\n\nimport dask.dataframe as dd\n\nfrom dask_sql.input_utils.base import BaseInputPlugin\n\n\nclass DaskInputPlugin(BaseInputPlugin):\n    \"\"\"Input Plugin for Dask DataFrames, just keeping them\"\"\"\n\n    def is_correct_input(\n        self, input_item: Any, table_name: str, format: str = None, **kwargs\n    ):\n        return isinstance(input_item, dd.DataFrame) or format == \"dask\"\n\n    def to_dc(\n        self,\n        input_item: Any,\n        table_name: str,\n        format: str = None,\n        gpu: bool = False,\n        **kwargs\n    ):\n        if gpu:  # pragma: no cover\n            try:\n                import dask_cudf  # noqa: F401\n            except ImportError:\n                raise ModuleNotFoundError(\n                    \"Setting `gpu=True` for table creation requires dask_cudf\"\n                )\n            return input_item.to_backend(\"cudf\", **kwargs)\n        return input_item\n"
  },
  {
    "path": "dask_sql/input_utils/hive.py",
    "content": "import ast\nimport logging\nimport os\nfrom functools import partial\nfrom typing import Any, Union\n\nimport dask.dataframe as dd\n\nfrom dask_sql._datafusion_lib import SqlTypeName\n\ntry:\n    from pyhive import hive\nexcept ImportError:  # pragma: no cover\n    hive = None\n\ntry:\n    import sqlalchemy\nexcept ImportError:  # pragma: no cover\n    sqlalchemy = None\n\nfrom dask_sql.input_utils.base import BaseInputPlugin\nfrom dask_sql.mappings import cast_column_type, sql_to_python_type\n\nlogger = logging.Logger(__name__)\n\n\nclass HiveInputPlugin(BaseInputPlugin):\n    \"\"\"Input Plugin from Hive\"\"\"\n\n    def is_correct_input(\n        self, input_item: Any, table_name: str, format: str = None, **kwargs\n    ):\n        is_hive_cursor = hive and isinstance(input_item, hive.Cursor)\n\n        return self.is_sqlalchemy_hive(input_item) or is_hive_cursor or format == \"hive\"\n\n    def is_sqlalchemy_hive(self, input_item: Any):\n        return sqlalchemy and isinstance(input_item, sqlalchemy.engine.base.Connection)\n\n    def to_dc(\n        self,\n        input_item: Any,\n        table_name: str,\n        format: str = None,\n        gpu: bool = False,\n        **kwargs,\n    ):\n        if gpu:  # pragma: no cover\n            raise Exception(\"Hive does not support gpu\")\n\n        table_name = kwargs.pop(\"hive_table_name\", table_name)\n        schema = kwargs.pop(\"hive_schema_name\", \"default\")\n\n        parsed = self._parse_hive_table_description(input_item, schema, table_name)\n        (\n            column_information,\n            table_information,\n            storage_information,\n            partition_information,\n        ) = parsed\n\n        logger.debug(\"Extracted hive information: \")\n        logger.debug(f\"column information: {column_information}\")\n        logger.debug(f\"table information: {table_information}\")\n        logger.debug(f\"storage information: {storage_information}\")\n        
logger.debug(f\"partition information: {partition_information}\")\n\n        # Convert column information\n        column_information = {\n            col: sql_to_python_type(SqlTypeName.fromString(col_type.upper()))\n            for col, col_type in column_information.items()\n        }\n\n        # Extract format information\n        if \"InputFormat\" in storage_information:\n            format = storage_information[\"InputFormat\"].split(\".\")[-1]\n        # databricks format is different, see https://github.com/dask-contrib/dask-sql/issues/83\n        elif \"InputFormat\" in table_information:  # pragma: no cover\n            format = table_information[\"InputFormat\"].split(\".\")[-1]\n        else:  # pragma: no cover\n            raise RuntimeError(\n                \"Do not understand the output of 'DESCRIBE FORMATTED <table>'\"\n            )\n\n        if (\n            format == \"TextInputFormat\" or format == \"SequenceFileInputFormat\"\n        ):  # pragma: no cover\n            storage_description = storage_information.get(\"Storage Desc Params\", {})\n            read_function = partial(\n                dd.read_csv,\n                sep=storage_description.get(\"field.delim\", \",\"),\n                header=None,\n            )\n        elif format == \"ParquetInputFormat\" or format == \"MapredParquetInputFormat\":\n            read_function = dd.read_parquet\n        elif format == \"OrcInputFormat\":  # pragma: no cover\n            read_function = dd.read_orc\n        elif format == \"JsonInputFormat\":  # pragma: no cover\n            read_function = dd.read_json\n        else:  # pragma: no cover\n            raise AttributeError(f\"Do not understand hive's table format {format}\")\n\n        def _normalize(loc):\n            if loc.startswith(\"dbfs:/\") and not loc.startswith(\n                \"dbfs://\"\n            ):  # pragma: no cover\n                # dask (or better: fsspec) needs to have the URL in a specific form\n            
    # starting with two // after the protocol\n                loc = f\"dbfs://{loc.lstrip('dbfs:')}\"\n            # file:// is not a known protocol\n            loc = loc.lstrip(\"file:\")\n            # Only allow files which do not start with . or _\n            # Especially, not allow the _SUCCESS files\n            return os.path.join(loc, \"[A-Za-z0-9-]*\")\n\n        def wrapped_read_function(location, column_information, **kwargs):\n            location = _normalize(location)\n            logger.debug(f\"Reading in hive data from {location}\")\n            if format == \"ParquetInputFormat\" or format == \"MapredParquetInputFormat\":\n                # Hack needed for parquet files.\n                # If the folder structure is like .../col=3/...\n                # parquet wants to read in the partition information.\n                # However, we add the partition information by ourself\n                # which will lead to problems afterwards\n                # Therefore tell parquet to only read in the columns\n                # we actually care right now\n                kwargs.setdefault(\"columns\", list(column_information.keys()))\n            else:  # pragma: no cover\n                # prevent python to optimize it away and make coverage not respect the\n                # pragma\n                dummy = 0  # noqa: F841\n            df = read_function(location, **kwargs)\n\n            logger.debug(f\"Applying column information: {column_information}\")\n            df = df.rename(columns=dict(zip(df.columns, column_information.keys())))\n\n            for col, expected_type in column_information.items():\n                df = cast_column_type(df, col, expected_type)\n\n            return df\n\n        if partition_information:\n            partition_list = self._parse_hive_partition_description(\n                input_item, schema, table_name\n            )\n            logger.debug(f\"Reading in partitions from {partition_list}\")\n\n            
tables = []\n            for partition in partition_list:\n                parsed = self._parse_hive_table_description(\n                    input_item, schema, table_name, partition=partition\n                )\n                (\n                    partition_column_information,\n                    partition_table_information,\n                    _,\n                    _,\n                ) = parsed\n\n                location = partition_table_information[\"Location\"]\n                table = wrapped_read_function(\n                    location, partition_column_information, **kwargs\n                )\n\n                # Now add the additional partition columns\n                partition_values = ast.literal_eval(\n                    partition_table_information[\"Partition Value\"]\n                )\n                # multiple partition column values returned comma separated string\n                if \",\" in partition_values:\n                    partition_values = [x.strip() for x in partition_values.split(\",\")]\n\n                logger.debug(\n                    f\"Applying additional partition information as columns: {partition_information}\"\n                )\n\n                partition_id = 0\n                for partition_key, partition_type in partition_information.items():\n                    table[partition_key] = partition_values[partition_id]\n                    table = cast_column_type(table, partition_key, partition_type)\n\n                    partition_id += 1\n\n                tables.append(table)\n\n            return dd.concat(tables)\n\n        location = table_information[\"Location\"]\n        df = wrapped_read_function(location, column_information, **kwargs)\n        return df\n\n    def _parse_hive_table_description(\n        self,\n        cursor: Union[\"sqlalchemy.engine.base.Connection\", \"hive.Cursor\"],\n        schema: str,\n        table_name: str,\n        partition: str = None,\n    ):\n        \"\"\"\n        
Extract all information from the output\n        of the DESCRIBE FORMATTED call, which is unfortunately\n        in a format not easily readable by machines.\n        \"\"\"\n        cursor.execute(\n            sqlalchemy.text(f\"USE {schema}\")\n            if self.is_sqlalchemy_hive(cursor)\n            else f\"USE {schema}\"\n        )\n        if partition:\n            # Hive wants quoted, comma separated list of partition keys\n            partition = partition.replace(\"=\", '=\"')\n            partition = partition.replace(\"/\", '\",') + '\"'\n            result = self._fetch_all_results(\n                cursor, f\"DESCRIBE FORMATTED {table_name} PARTITION ({partition})\"\n            )\n        else:\n            result = self._fetch_all_results(cursor, f\"DESCRIBE FORMATTED {table_name}\")\n\n        logger.debug(f\"Got information from hive: {result}\")\n\n        table_information = {}\n        column_information = {}  # using the fact that dicts are insertion ordered\n        storage_information = {}\n        partition_information = {}\n        mode = \"column\"\n        last_field = None\n\n        for key, value, value2 in result:\n            key = key.strip().rstrip(\":\") if key else \"\"\n            value = value.strip() if value else \"\"\n            value2 = value2.strip() if value2 else \"\"\n\n            # That is just a comment line, we can skip it\n            if key == \"# col_name\":\n                continue\n\n            if (\n                key == \"# Detailed Table Information\"\n                or key == \"# Detailed Partition Information\"\n            ):\n                mode = \"table\"\n            elif key == \"# Storage Information\":\n                mode = \"storage\"\n            elif key == \"# Partition Information\":\n                mode = \"partition\"\n            elif key.startswith(\"#\"):\n                mode = None  # pragma: no cover\n            elif key:\n                if not value:\n                  
  value = dict()\n                if mode == \"column\":\n                    column_information[key] = value\n                    last_field = column_information[key]\n                elif mode == \"storage\":\n                    storage_information[key] = value\n                    last_field = storage_information[key]\n                elif mode == \"table\":\n                    # Hive partition values come in a bracketed list\n                    # quoted partition values work regardless of partition column type\n                    if key == \"Partition Value\":\n                        value = '\"' + value.strip(\"[]\") + '\"'\n                    table_information[key] = value\n                    last_field = table_information[key]\n                elif mode == \"partition\":\n                    partition_information[key] = value\n                    last_field = partition_information[key]\n                else:  # pragma: no cover\n                    # prevent python to optimize it away and make coverage not respect the\n                    # pragma\n                    dummy = 0  # noqa: F841\n            elif value and last_field is not None:\n                last_field[value] = value2\n\n        return (\n            column_information,\n            table_information,\n            storage_information,\n            partition_information,\n        )\n\n    def _parse_hive_partition_description(\n        self,\n        cursor: Union[\"sqlalchemy.engine.base.Connection\", \"hive.Cursor\"],\n        schema: str,\n        table_name: str,\n    ):\n        \"\"\"\n        Extract all partition information for a given table\n        \"\"\"\n        cursor.execute(\n            sqlalchemy.text(f\"USE {schema}\")\n            if self.is_sqlalchemy_hive(cursor)\n            else f\"USE {schema}\"\n        )\n        result = self._fetch_all_results(cursor, f\"SHOW PARTITIONS {table_name}\")\n\n        return [row[0] for row in result]\n\n    def 
_fetch_all_results(\n        self,\n        cursor: Union[\"sqlalchemy.engine.base.Connection\", \"hive.Cursor\"],\n        sql: str,\n    ):\n        \"\"\"\n        The pyhive.Cursor and the sqlalchemy connection behave slightly different.\n        The former has the fetchall method on the cursor,\n        whereas the latter on the executed query.\n        \"\"\"\n        result = cursor.execute(\n            sqlalchemy.text(sql) if self.is_sqlalchemy_hive(cursor) else sql\n        )\n\n        try:\n            return result.fetchall()\n        except AttributeError:  # pragma: no cover\n            return cursor.fetchall()\n"
  },
  {
    "path": "dask_sql/input_utils/intake.py",
    "content": "from typing import Any\n\ntry:\n    import intake\nexcept ImportError:  # pragma: no cover\n    intake = None\n\nfrom dask_sql.input_utils.base import BaseInputPlugin\n\n\nclass IntakeCatalogInputPlugin(BaseInputPlugin):\n    \"\"\"Input Plugin for Intake Catalogs, getting the table in dask format\"\"\"\n\n    def is_correct_input(\n        self, input_item: Any, table_name: str, format: str = None, **kwargs\n    ):\n        return intake and (\n            isinstance(input_item, intake.catalog.Catalog) or format == \"intake\"\n        )\n\n    def to_dc(\n        self,\n        input_item: Any,\n        table_name: str,\n        format: str = None,\n        gpu: bool = False,\n        **kwargs,\n    ):\n        if gpu:  # pragma: no cover\n            raise NotImplementedError(\"Intake does not support gpu\")\n\n        table_name = kwargs.pop(\"intake_table_name\", table_name)\n        catalog_kwargs = kwargs.pop(\"catalog_kwargs\", {})\n\n        if isinstance(input_item, str):\n            input_item = intake.open_catalog(input_item, **catalog_kwargs)\n\n        return input_item[table_name].to_dask(**kwargs)\n"
  },
  {
    "path": "dask_sql/input_utils/location.py",
    "content": "import os\nfrom typing import Any\n\nimport dask.dataframe as dd\nfrom distributed.client import default_client\n\nfrom dask_sql.input_utils.base import BaseInputPlugin\nfrom dask_sql.input_utils.convert import InputUtil\n\n\nclass LocationInputPlugin(BaseInputPlugin):\n    \"\"\"Input Plugin for everything, which can be read in from a file (on disk, remote etc.)\"\"\"\n\n    def is_correct_input(\n        self, input_item: Any, table_name: str, format: str = None, **kwargs\n    ):\n        return isinstance(input_item, str)\n\n    def to_dc(\n        self,\n        input_item: Any,\n        table_name: str,\n        format: str = None,\n        gpu: bool = False,\n        **kwargs,\n    ):\n        if format == \"memory\":\n            client = default_client()\n            df = client.get_dataset(input_item, **kwargs)\n\n            plugin_list = InputUtil.get_plugins()\n\n            for plugin in plugin_list:\n                if plugin.is_correct_input(df, table_name, format, **kwargs):\n                    return plugin.to_dc(df, table_name, format, gpu, **kwargs)\n        if not format:\n            _, extension = os.path.splitext(input_item)\n\n            format = extension.lstrip(\".\")\n        try:\n            if gpu:  # pragma: no cover\n                try:\n                    import dask_cudf\n                except ImportError:\n                    raise ModuleNotFoundError(\n                        \"Setting `gpu=True` for table creation requires dask-cudf\"\n                    )\n                read_function = getattr(dask_cudf, f\"read_{format}\")\n            else:\n                read_function = getattr(dd, f\"read_{format}\")\n        except AttributeError:\n            raise AttributeError(f\"Can not read files of format {format}\")\n\n        return read_function(input_item, **kwargs)\n"
  },
  {
    "path": "dask_sql/input_utils/pandaslike.py",
    "content": "import dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql.input_utils.base import BaseInputPlugin\n\n\nclass PandasLikeInputPlugin(BaseInputPlugin):\n    \"\"\"Input Plugin for Pandas Like DataFrames, which get converted to dask DataFrames\"\"\"\n\n    def is_correct_input(\n        self, input_item, table_name: str, format: str = None, **kwargs\n    ):\n        return (\n            dd.utils.is_dataframe_like(input_item)\n            and not isinstance(input_item, dd.DataFrame)\n        ) or format == \"dask\"\n\n    def to_dc(\n        self,\n        input_item,\n        table_name: str,\n        format: str = None,\n        gpu: bool = False,\n        **kwargs,\n    ):\n        npartitions = kwargs.pop(\"npartitions\", 1)\n        if gpu:  # pragma: no cover\n            try:\n                import cudf\n            except ImportError:\n                raise ModuleNotFoundError(\n                    \"Setting `gpu=True` for table creation requires cudf\"\n                )\n\n            if isinstance(input_item, pd.DataFrame):\n                input_item = cudf.from_pandas(input_item)\n\n        return dd.from_pandas(input_item, npartitions=npartitions, **kwargs)\n"
  },
  {
    "path": "dask_sql/input_utils/sqlalchemy.py",
    "content": "from typing import Any\n\nfrom dask_sql.input_utils.hive import HiveInputPlugin\n\n\nclass SqlalchemyHiveInputPlugin(HiveInputPlugin):\n    \"\"\"Input Plugin from sqlalchemy string\"\"\"\n\n    def is_correct_input(\n        self, input_item: Any, table_name: str, format: str = None, **kwargs\n    ):\n        correct_prefix = isinstance(input_item, str) and (\n            input_item.startswith(\"hive://\")\n            or input_item.startswith(\"databricks+pyhive://\")\n        )\n        return correct_prefix\n\n    def to_dc(\n        self,\n        input_item: Any,\n        table_name: str,\n        format: str = None,\n        gpu: bool = False,\n        **kwargs\n    ):  # pragma: no cover\n        if gpu:\n            raise NotImplementedError(\"Hive does not support gpu\")\n\n        import sqlalchemy\n\n        engine_kwargs = {}\n        if \"connect_args\" in kwargs:\n            engine_kwargs[\"connect_args\"] = kwargs.pop(\"connect_args\")\n\n        if format is not None:\n            raise AttributeError(\n                \"Format specified and sqlalchemy connection string set!\"\n            )\n\n        cursor = sqlalchemy.create_engine(input_item, **engine_kwargs).connect()\n        return super().to_dc(cursor, table_name=table_name, **kwargs)\n"
  },
  {
    "path": "dask_sql/integrations/__init__.py",
    "content": ""
  },
  {
    "path": "dask_sql/integrations/fugue.py",
    "content": "try:\n    import fugue\n    import fugue_dask\n    from dask.distributed import Client\n    from fugue import WorkflowDataFrame, register_execution_engine\n    from fugue_sql import FugueSQLWorkflow\n    from triad import run_at_def\n    from triad.utils.convert import get_caller_global_local_vars\nexcept ImportError:  # pragma: no cover\n    raise ImportError(\n        \"Can not load the fugue module. If you want to use this integration, you need to install it.\"\n    )\n\nfrom typing import Any, Optional\n\nimport dask.dataframe as dd\n\nfrom dask_sql.context import Context\n\n\n@run_at_def\ndef _register_engines() -> None:\n    \"\"\"Register (overwrite) the default Dask execution engine of Fugue. This\n    function is invoked as an entrypoint, users don't need to call it explicitly.\n    \"\"\"\n    register_execution_engine(\n        \"dask\",\n        lambda conf, **kwargs: DaskSQLExecutionEngine(conf=conf),\n        on_dup=\"overwrite\",\n    )\n\n    register_execution_engine(\n        Client,\n        lambda engine, conf, **kwargs: DaskSQLExecutionEngine(\n            dask_client=engine, conf=conf\n        ),\n        on_dup=\"overwrite\",\n    )\n\n\nclass DaskSQLEngine(fugue.execution.execution_engine.SQLEngine):\n    \"\"\"\n    SQL engine for fugue which uses dask-sql instead of the native\n    SQL implementation.\n\n    Please note, that so far the native SQL engine in fugue\n    understands a larger set of SQL commands, but in turns is\n    (on average) slower in computation and scaling.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        \"\"\"Create a new instance.\"\"\"\n        super().__init__(*args, **kwargs)\n\n    @property\n    def is_distributed(self) -> bool:\n        return True\n\n    def select(\n        self, dfs: fugue.dataframe.DataFrames, statement: str\n    ) -> fugue.dataframe.DataFrame:\n        \"\"\"Send the SQL command to the dask-sql context and register all temporary dataframes\"\"\"\n        c = 
Context()\n\n        for k, v in dfs.items():\n            c.create_table(k, self.execution_engine.to_df(v).native)\n\n        df = c.sql(statement)\n        return fugue_dask.dataframe.DaskDataFrame(df)\n\n\nclass DaskSQLExecutionEngine(fugue_dask.DaskExecutionEngine):\n    \"\"\"\n    Execution engine for fugue which has dask-sql as SQL engine\n    configured.\n\n    Please note, that so far the native SQL engine in fugue\n    understands a larger set of SQL commands, but in turns is\n    (on average) slower in computation and scaling.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        \"\"\"Create a new instance.\"\"\"\n        super().__init__(*args, **kwargs)\n        self._default_sql_engine = DaskSQLEngine(self)\n\n    @property\n    def default_sql_engine(self) -> fugue.execution.execution_engine.SQLEngine:\n        return self._default_sql_engine\n\n\ndef fsql_dask(\n    sql: str,\n    ctx: Optional[Context] = None,\n    register: bool = False,\n    fugue_conf: Any = None,\n) -> dict[str, dd.DataFrame]:\n    \"\"\"FugueSQL utility function that can consume Context directly. FugueSQL is a language\n    extending standard SQL. It makes SQL eligible to describe end to end workflows. It also\n    enables you to invoke python extensions in the SQL like language.\n\n    For more, please read\n    `FugueSQL Tutorial <https://fugue-tutorials.readthedocs.io/en/latest/tutorials/fugue_sql/index.html/>`_\n\n    Args:\n        sql (:obj:`str`): Fugue SQL statement\n        ctx (:class:`dask_sql.Context`): The context to operate on, defaults to None\n        register (:obj:`bool`): Whether to register named steps back to the context\n          (if provided), defaults to False\n        fugue_conf (:obj:`Any`): a dictionary like object containing Fugue specific configs\n\n    Example:\n        .. 
code-block:: python\n\n            # define a custom prepartition function for FugueSQL\n            def median(df: pd.DataFrame) -> pd.DataFrame:\n                df[\"y\"] = df[\"y\"].median()\n                return df.head(1)\n\n            # create a context with some tables\n            c = Context()\n            ...\n\n            # run a FugueSQL query using the context as input\n            query = '''\n                j = SELECT df1.*, df2.x\n                    FROM df1 INNER JOIN df2 ON df1.key = df2.key\n                    PERSIST\n                TAKE 5 ROWS PREPARTITION BY x PRESORT key\n                PRINT\n                TRANSFORM j PREPARTITION BY x USING median\n                PRINT\n                '''\n            result = fsql_dask(query, c, register=True)\n\n            assert \"j\" in result\n            assert \"j\" in c.tables\n    \"\"\"\n    _global, _local = get_caller_global_local_vars()\n\n    dag = FugueSQLWorkflow()\n    dfs = (\n        {}\n        if ctx is None\n        else {k: dag.df(v.df) for k, v in ctx.schema[ctx.schema_name].tables.items()}\n    )\n    result = dag._sql(sql, _global, _local, **dfs)\n    dag.run(DaskSQLExecutionEngine(conf=fugue_conf))\n\n    result_dfs = {\n        k: v.result.native\n        for k, v in result.items()\n        if isinstance(v, WorkflowDataFrame)\n    }\n    if register and ctx is not None:\n        for k, v in result_dfs.items():\n            ctx.create_table(k, v)\n    return result_dfs\n"
  },
  {
    "path": "dask_sql/integrations/ipython.py",
    "content": "import time\nfrom typing import TYPE_CHECKING\n\nfrom dask_sql.mappings import _SQL_TO_PYTHON_FRAMES\nfrom dask_sql.physical.rex.core import RexCallPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n\n# That is definitely not pretty, but there seems to be no better way...\nKEYWORDS = [\n    \"and\",\n    \"as\",\n    \"asc\",\n    \"between\",\n    \"by\",\n    \"columns\",\n    \"count\",\n    \"create\",\n    \"delete\",\n    \"desc\",\n    \"describe\",\n    \"distinct\",\n    \"exists\",\n    \"from\",\n    \"group\",\n    \"having\",\n    \"if\",\n    \"in\",\n    \"inner\",\n    \"insert\",\n    \"into\",\n    \"is\",\n    \"join\",\n    \"left\",\n    \"like\",\n    \"model\",\n    \"not\",\n    \"on\",\n    \"or\",\n    \"order\",\n    \"outer\",\n    \"right\",\n    \"schemas\",\n    \"select\",\n    \"set\",\n    \"show\",\n    \"table\",\n    \"union\",\n    \"where\",\n]\n\n\ndef ipython_integration(\n    context: \"dask_sql.Context\",\n    auto_include: bool,\n    disable_highlighting: bool,\n) -> None:  # pragma: no cover\n    \"\"\"Integrate the context with jupyter notebooks. 
Have a look into :ref:`Context.ipython_magic`.\"\"\"\n    _register_ipython_magic(context, auto_include=auto_include)\n    if not disable_highlighting:\n        _register_syntax_highlighting()\n\n\ndef _register_ipython_magic(\n    c: \"dask_sql.Context\", auto_include: bool\n) -> None:  # pragma: no cover\n    from IPython.core.magic import needs_local_scope, register_line_cell_magic\n\n    @needs_local_scope\n    def sql(line, cell, local_ns):\n        if cell is None:\n            # the magic function was called inline\n            cell = line\n\n        sql_statement = cell.format(**local_ns)\n\n        dataframes = {}\n        if auto_include:\n            dataframes = c._get_tables_from_stack()\n\n        t0 = time.time()\n        res = c.sql(sql_statement, return_futures=False, dataframes=dataframes)\n        if (\n            \"CREATE OR REPLACE TABLE\" in sql_statement\n            or \"CREATE OR REPLACE VIEW\" in sql_statement\n        ):\n            table = sql_statement.split(\"CREATE OR REPLACE\")[1]\n            table = table.replace(\"TABLE\", \"\").replace(\"VIEW\", \"\").split()[0].strip()\n            res = c.sql(f\"SELECT * FROM {table}\").tail()\n        elif \"CREATE TABLE\" in sql_statement or \"CREATE VIEW\" in sql_statement:\n            table = sql_statement.split(\"CREATE\")[1]\n            table = table.replace(\"TABLE\", \"\").replace(\"VIEW\", \"\").split()[0].strip()\n            res = c.sql(f\"SELECT * FROM {table}\").tail()\n        print(f\"Execution time: {time.time() - t0:.2f}s\")\n        return res\n\n    # Register a new magic function\n    magic_func = register_line_cell_magic(sql)\n    magic_func.MAGIC_NO_VAR_EXPAND_ATTR = True\n\n\ndef _register_syntax_highlighting():  # pragma: no cover\n    import json\n\n    from IPython.core import display\n\n    # JS snippet to use the created mime type highlighting\n    _JS_ENABLE_DASK_SQL = r\"\"\"\n    require(['notebook/js/codecell'], function(codecell) {\n        
codecell.CodeCell.options_default.highlight_modes['magic_text/x-dasksql'] = {'reg':[/%%sql/]} ;\n        Jupyter.notebook.events.on('kernel_ready.Kernel', function(){\n        Jupyter.notebook.get_cells().map(function(cell){\n            if (cell.cell_type == 'code'){ cell.auto_highlight(); } }) ;\n        });\n    });\n    \"\"\"\n\n    types = map(str, _SQL_TO_PYTHON_FRAMES.keys())\n    functions = list(RexCallPlugin.OPERATION_MAPPING.keys())\n\n    # Create a new mimetype\n    mime_type = {\n        \"name\": \"sql\",\n        \"keywords\": _create_set(KEYWORDS + functions),\n        \"builtin\": _create_set(types),\n        \"atoms\": _create_set([\"false\", \"true\", \"null\"]),\n        # \"operatorChars\": /^[*\\/+\\-%<>!=~&|^]/,\n        \"dateSQL\": _create_set([\"time\"]),\n        # More information\n        # https://opensource.apple.com/source/WebInspectorUI/WebInspectorUI-7600.8.3/UserInterface/External/CodeMirror/sql.js.auto.html\n        \"support\": _create_set([\"ODBCdotTable\", \"doubleQuote\", \"zerolessFloat\"]),\n    }\n\n    # Code original from fugue-sql, adjusted for dask-sql and using some more customizations\n    js = (\n        r\"\"\"\n    require([\"codemirror/lib/codemirror\"]);\n\n    // We define a new mime type for syntax highlighting\n    CodeMirror.defineMIME(\"text/x-dasksql\", \"\"\"\n        + json.dumps(mime_type)\n        + r\"\"\"\n    );\n    CodeMirror.modeInfo.push({\n        name: \"Dask SQL\",\n        mime: \"text/x-dasksql\",\n        mode: \"sql\"\n    });\n    \"\"\"\n    )\n\n    display.display_javascript(js + _JS_ENABLE_DASK_SQL, raw=True)\n\n\ndef _create_set(keys: list[str]) -> dict[str, bool]:  # pragma: no cover\n    \"\"\"Small helper function to turn a list into the correct format for codemirror\"\"\"\n    return {key: True for key in keys}\n"
  },
  {
    "path": "dask_sql/mappings.py",
    "content": "import logging\nfrom datetime import datetime\nfrom typing import Any\n\nimport dask.array as da\nimport dask.config as dask_config\nimport dask.dataframe as dd\nimport numpy as np\nimport pandas as pd\n\nfrom dask_sql._datafusion_lib import DaskTypeMap, SqlTypeName\n\nlogger = logging.getLogger(__name__)\n\n\n# Default mapping between python types and SQL types\n_PYTHON_TO_SQL = {\n    np.float64: SqlTypeName.DOUBLE,\n    pd.Float64Dtype(): SqlTypeName.DOUBLE,\n    float: SqlTypeName.FLOAT,\n    np.float32: SqlTypeName.FLOAT,\n    pd.Float32Dtype(): SqlTypeName.FLOAT,\n    np.int64: SqlTypeName.BIGINT,\n    pd.Int64Dtype(): SqlTypeName.BIGINT,\n    int: SqlTypeName.INTEGER,\n    np.int32: SqlTypeName.INTEGER,\n    pd.Int32Dtype(): SqlTypeName.INTEGER,\n    np.int16: SqlTypeName.SMALLINT,\n    pd.Int16Dtype(): SqlTypeName.SMALLINT,\n    np.int8: SqlTypeName.TINYINT,\n    pd.Int8Dtype(): SqlTypeName.TINYINT,\n    np.uint64: SqlTypeName.BIGINT,\n    pd.UInt64Dtype(): SqlTypeName.BIGINT,\n    np.uint32: SqlTypeName.INTEGER,\n    pd.UInt32Dtype(): SqlTypeName.INTEGER,\n    np.uint16: SqlTypeName.SMALLINT,\n    pd.UInt16Dtype(): SqlTypeName.SMALLINT,\n    np.uint8: SqlTypeName.TINYINT,\n    pd.UInt8Dtype(): SqlTypeName.TINYINT,\n    np.bool_: SqlTypeName.BOOLEAN,\n    pd.BooleanDtype(): SqlTypeName.BOOLEAN,\n    str: SqlTypeName.VARCHAR,\n    np.object_: SqlTypeName.VARCHAR,\n    pd.StringDtype(): SqlTypeName.VARCHAR,\n    np.datetime64: SqlTypeName.TIMESTAMP,\n}\n\n# Default mapping between SQL types and python types\n# for values\n_SQL_TO_PYTHON_SCALARS = {\n    \"SqlTypeName.DOUBLE\": np.float64,\n    \"SqlTypeName.FLOAT\": np.float32,\n    \"SqlTypeName.DECIMAL\": np.float32,\n    \"SqlTypeName.BIGINT\": np.int64,\n    \"SqlTypeName.INTEGER\": np.int32,\n    \"SqlTypeName.SMALLINT\": np.int16,\n    \"SqlTypeName.TINYINT\": np.int8,\n    \"SqlTypeName.BOOLEAN\": np.bool_,\n    \"SqlTypeName.VARCHAR\": str,\n    \"SqlTypeName.CHAR\": str,\n    
\"SqlTypeName.NULL\": type(None),\n    \"SqlTypeName.SYMBOL\": lambda x: x,  # SYMBOL is a special type used for e.g. flags etc. We just keep it\n}\n\n# Default mapping between SQL types and python types\n# for data frames\n_SQL_TO_PYTHON_FRAMES = {\n    \"SqlTypeName.DOUBLE\": np.float64,\n    \"SqlTypeName.FLOAT\": np.float32,\n    \"SqlTypeName.DECIMAL\": np.float64,  # We use np.float64 always, even though we might be able to use a smaller type\n    \"SqlTypeName.BIGINT\": pd.Int64Dtype(),\n    \"SqlTypeName.INTEGER\": pd.Int32Dtype(),\n    \"SqlTypeName.SMALLINT\": pd.Int16Dtype(),\n    \"SqlTypeName.TINYINT\": pd.Int8Dtype(),\n    \"SqlTypeName.BOOLEAN\": pd.BooleanDtype(),\n    \"SqlTypeName.VARCHAR\": pd.StringDtype(),\n    \"SqlTypeName.CHAR\": pd.StringDtype(),\n    \"SqlTypeName.DATE\": np.dtype(\n        \"<M8[ns]\"\n    ),  # TODO: ideally this would be np.dtype(\"<M8[D]\") but that doesn't work for Pandas\n    \"SqlTypeName.TIME\": np.dtype(\"<M8[ns]\"),\n    \"SqlTypeName.TIMESTAMP\": np.dtype(\"<M8[ns]\"),\n    \"SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE\": pd.DatetimeTZDtype(\n        unit=\"ns\", tz=\"UTC\"\n    ),  # Everything is converted to UTC. 
So far, this did not break\n    \"SqlTypeName.INTERVAL_DAY\": np.dtype(\"<m8[ns]\"),\n    \"SqlTypeName.INTERVAL_MONTH_DAY_NANOSECOND\": np.dtype(\"<m8[ns]\"),\n    \"SqlTypeName.NULL\": type(None),\n}\n\n\ndef python_to_sql_type(python_type) -> \"DaskTypeMap\":\n    \"\"\"Mapping between python and SQL types.\"\"\"\n\n    if python_type in (int, float):\n        python_type = np.dtype(python_type)\n    elif python_type is str:\n        python_type = np.dtype(\"object\")\n\n    if isinstance(python_type, np.dtype):\n        python_type = python_type.type\n\n    if isinstance(python_type, pd.DatetimeTZDtype):\n        return DaskTypeMap(\n            SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE,\n            unit=str(python_type.unit),\n            tz=str(python_type.tz),\n        )\n\n    if is_decimal(python_type):\n        return DaskTypeMap(\n            SqlTypeName.DECIMAL,\n            precision=python_type.precision,\n            scale=python_type.scale,\n        )\n\n    try:\n        return DaskTypeMap(_PYTHON_TO_SQL[python_type])\n    except KeyError:  # pragma: no cover\n        raise NotImplementedError(\n            f\"The python type {python_type} is not implemented (yet)\"\n        )\n\n\ndef parse_datetime(obj):\n    formats = [\n        \"%Y-%m-%d %H:%M:%S\",\n        \"%Y-%m-%d\",\n        \"%d-%m-%Y %H:%M:%S\",\n        \"%d-%m-%Y\",\n        \"%m/%d/%Y %H:%M:%S\",\n        \"%m/%d/%Y\",\n    ]\n\n    for f in formats:\n        try:\n            datetime_obj = datetime.strptime(obj, f)\n            return datetime_obj\n        except ValueError:\n            pass\n\n    raise ValueError(\"Unable to parse datetime: \" + obj)\n\n\ndef sql_to_python_value(sql_type: \"SqlTypeName\", literal_value: Any) -> Any:\n    \"\"\"Mapping between SQL and python values (of correct type).\"\"\"\n    # In most of the cases, we turn the value first into a string.\n    # That might not be the most efficient thing to do,\n    # but works for all types (so far)\n    # 
Additionally, a literal type is not used\n    # so often anyways.\n\n    logger.debug(\n        f\"sql_to_python_value -> sql_type: {sql_type} literal_value: {literal_value}\"\n    )\n\n    if sql_type == SqlTypeName.CHAR or sql_type == SqlTypeName.VARCHAR:\n        # Some varchars contain an additional encoding\n        # in the format _ENCODING'string'\n        literal_value = str(literal_value)\n        if literal_value.startswith(\"_\"):\n            encoding, literal_value = literal_value.split(\"'\", 1)\n            literal_value = literal_value.rstrip(\"'\")\n            literal_value = literal_value.encode(encoding=encoding)\n            return literal_value.decode(encoding=encoding)\n\n        return literal_value\n\n    elif (\n        sql_type == SqlTypeName.DECIMAL\n        and dask_config.get(\"sql.mappings.decimal_support\") == \"cudf\"\n    ):\n        from decimal import Decimal\n\n        python_type = Decimal\n\n    elif sql_type == SqlTypeName.INTERVAL_DAY:\n        return np.timedelta64(literal_value[0], \"D\") + np.timedelta64(\n            literal_value[1], \"ms\"\n        )\n    elif sql_type == SqlTypeName.INTERVAL:\n        # check for finer granular interval types, e.g., INTERVAL MONTH, INTERVAL YEAR\n        try:\n            interval_type = str(sql_type).split()[1].lower()\n\n            if interval_type in {\"year\", \"quarter\", \"month\"}:\n                # if sql_type is INTERVAL YEAR, Calcite will convert to months\n                delta = pd.tseries.offsets.DateOffset(months=float(str(literal_value)))\n                return delta\n        except IndexError:  # pragma: no cover\n            # no finer granular interval type specified\n            pass\n        except TypeError:  # pragma: no cover\n            # interval type is not recognized, fall back to default case\n            pass\n\n        # Calcite will always convert INTERVAL types except YEAR, QUARTER, MONTH to milliseconds\n        # Issue: if sql_type is INTERVAL 
MICROSECOND, and value <= 1000, literal_value will be rounded to 0\n        return np.timedelta64(literal_value, \"ms\")\n    elif sql_type == SqlTypeName.INTERVAL_MONTH_DAY_NANOSECOND:\n        # DataFusion assumes 30 days per month. Therefore we multiply number of months by 30 and add to days\n        return np.timedelta64(\n            (literal_value[0] * 30) + literal_value[1], \"D\"\n        ) + np.timedelta64(literal_value[2], \"ns\")\n\n    elif sql_type == SqlTypeName.BOOLEAN:\n        return bool(literal_value)\n\n    elif (\n        sql_type == SqlTypeName.TIMESTAMP\n        or sql_type == SqlTypeName.TIME\n        or sql_type == SqlTypeName.DATE\n    ):\n        if isinstance(literal_value, str):\n            literal_value = parse_datetime(literal_value)\n            literal_value = np.datetime64(literal_value)\n        elif str(literal_value) == \"None\":\n            # NULL time\n            return pd.NaT  # pragma: no cover\n        if sql_type == SqlTypeName.DATE:\n            return literal_value.astype(\"<M8[D]\")\n        return literal_value.astype(\"<M8[ns]\")\n    else:\n        try:\n            python_type = _SQL_TO_PYTHON_SCALARS[str(sql_type)]\n        except KeyError:  # pragma: no cover\n            raise NotImplementedError(\n                f\"The SQL type {sql_type} is not implemented (yet)\"\n            )\n\n    literal_value = str(literal_value)\n\n    # empty literal type. 
We return NaN if possible\n    if literal_value == \"None\":\n        if isinstance(python_type(), np.floating):\n            return np.nan\n        else:\n            return pd.NA\n\n    return python_type(literal_value)\n\n\ndef sql_to_python_type(sql_type: \"SqlTypeName\", *args) -> type:\n    \"\"\"Turn an SQL type into a dataframe dtype\"\"\"\n    try:\n        if (\n            sql_type == SqlTypeName.DECIMAL\n            and dask_config.get(\"sql.mappings.decimal_support\") == \"cudf\"\n        ):\n            try:\n                import cudf\n            except ImportError:\n                raise ModuleNotFoundError(\n                    \"Setting `sql.mappings.decimal_support=cudf` requires cudf\"\n                )\n            return cudf.Decimal128Dtype(*args)\n        return _SQL_TO_PYTHON_FRAMES[str(sql_type)]\n    except KeyError:  # pragma: no cover\n        raise NotImplementedError(\n            f\"The SQL type {str(sql_type)} is not implemented (yet)\"\n        )\n\n\ndef similar_type(lhs: type, rhs: type) -> bool:\n    \"\"\"\n    Measure similarity between types.\n    Two types are similar if they both come from the same family,\n    e.g. 
both are ints, uints, floats, strings etc.\n    Size or precision is not taken into account.\n\n    TODO: nullability is not checked so far.\n    \"\"\"\n    pdt = pd.api.types\n    is_uint = pdt.is_unsigned_integer_dtype\n    is_sint = pdt.is_signed_integer_dtype\n    is_float = pdt.is_float_dtype\n    is_object = pdt.is_object_dtype\n    is_string = pdt.is_string_dtype\n    is_dt_ns = pdt.is_datetime64_ns_dtype\n    is_dt_tz = lambda t: is_dt_ns(t) and isinstance(t, pd.DatetimeTZDtype)\n    is_dt_ntz = lambda t: is_dt_ns(t) and not isinstance(t, pd.DatetimeTZDtype)\n    is_td_ns = pdt.is_timedelta64_ns_dtype\n    is_bool = pdt.is_bool_dtype\n\n    checks = [\n        is_uint,\n        is_sint,\n        is_float,\n        is_object,\n        # is_string_dtype considers decimal columns to be string columns\n        lambda x: is_string(x) and not is_decimal(x),\n        is_dt_tz,\n        is_dt_ntz,\n        is_td_ns,\n        is_bool,\n        is_decimal,\n    ]\n\n    for check in checks:\n        if check(lhs) and check(rhs):\n            # check that decimal columns have equal precision/scale\n            if check is is_decimal:\n                return lhs.precision == rhs.precision and lhs.scale == rhs.scale\n            return True\n\n    return False\n\n\ndef cast_column_type(\n    df: dd.DataFrame, column_name: str, expected_type: type\n) -> dd.DataFrame:\n    \"\"\"\n    Cast the type of the given column to the expected type,\n    if they are far \"enough\" away.\n    This means, a float will never be converted into a double\n    or a tinyint into another int - but a string to an integer etc.\n    \"\"\"\n    current_type = df[column_name].dtype\n\n    logger.debug(\n        f\"Column {column_name} has type {current_type}, expecting {expected_type}...\"\n    )\n\n    casted_column = cast_column_to_type(df[column_name], expected_type)\n\n    if casted_column is not None:\n        df[column_name] = casted_column\n\n    return df\n\n\ndef 
cast_column_to_type(col: dd.Series, expected_type: str):\n    \"\"\"Cast the given column to the expected type\"\"\"\n    pdt = pd.api.types\n\n    is_dt_ns = pdt.is_datetime64_ns_dtype\n    is_dt_tz = lambda t: is_dt_ns(t) and isinstance(t, pd.DatetimeTZDtype)\n    is_dt_ntz = lambda t: is_dt_ns(t) and not isinstance(t, pd.DatetimeTZDtype)\n\n    current_type = col.dtype\n\n    if similar_type(current_type, expected_type):\n        logger.debug(\"...not converting.\")\n        return None\n\n    if pdt.is_integer_dtype(expected_type):\n        if pdt.is_float_dtype(current_type):\n            logger.debug(\"...truncating...\")\n            # Currently \"trunc\" cannot be applied to NA (the pandas missing value type),\n            # because NA is a different type. It works with np.nan though.\n            # For our use case, that does not matter, as the conversion to integer later\n            # will convert both NA and np.nan to NA.\n            # np.nan is used instead of the np.NaN alias, which was removed in NumPy 2.0.\n            col = da.trunc(col.fillna(value=np.nan))\n        elif pdt.is_timedelta64_dtype(current_type):\n            logger.debug(f\"Explicitly casting from {current_type} to np.int64\")\n            return col.astype(np.int64)\n\n    if is_dt_tz(current_type) and is_dt_ntz(expected_type):\n        # casting from timezone-aware to timezone-naive datatypes with astype is deprecated in pandas 2\n        return col.dt.tz_localize(None)\n\n    logger.debug(f\"Need to cast from {current_type} to {expected_type}\")\n    return col.astype(expected_type)\n\n\ndef is_decimal(dtype):\n    \"\"\"\n    Check if dtype is a decimal type\n    \"\"\"\n    return \"decimal\" in str(dtype).lower()\n"
  },
  {
    "path": "dask_sql/physical/__init__.py",
    "content": ""
  },
  {
    "path": "dask_sql/physical/rel/__init__.py",
    "content": "from .convert import RelConverter\n"
  },
  {
    "path": "dask_sql/physical/rel/base.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING, Optional\n\nimport dask.dataframe as dd\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.mappings import cast_column_type, sql_to_python_type\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan, RelDataType\n\nlogger = logging.getLogger(__name__)\n\n\nclass BaseRelPlugin:\n    \"\"\"\n    Base class for all plugins to convert between\n    a RelNode to a python expression (dask dataframe).\n\n    Derived classed needs to override the class_name attribute\n    and the convert method.\n    \"\"\"\n\n    class_name = None\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> dd.DataFrame:\n        \"\"\"Base method to implement\"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    def fix_column_to_row_type(\n        cc: ColumnContainer, row_type: \"RelDataType\", join_type: Optional[str] = None\n    ) -> ColumnContainer:\n        \"\"\"\n        Make sure that the given column container\n        has the column names specified by the row type.\n        We assume that the column order is already correct\n        and will just \"blindly\" rename the columns.\n        \"\"\"\n        field_names = [str(x) for x in row_type.getFieldNames()]\n        if join_type in (\"leftsemi\", \"leftanti\"):\n            field_names = field_names[: len(cc.columns)]\n\n        logger.debug(f\"Renaming {cc.columns} to {field_names}\")\n        cc = cc.rename_handle_duplicates(\n            from_columns=cc.columns, to_columns=field_names\n        )\n\n        # TODO: We can also check for the types here and do any conversions if needed\n        return cc.limit_to(field_names)\n\n    @staticmethod\n    def check_columns_from_row_type(df: dd.DataFrame, row_type: \"RelDataType\"):\n        \"\"\"\n        Similar to `self.fix_column_to_row_type`, but this time\n        check for the correct column names instead 
of\n        applying them.\n        \"\"\"\n        field_names = [str(x) for x in row_type.getFieldNames()]\n\n        assert list(df.columns) == field_names\n\n        # TODO: similar to self.fix_column_to_row_type, we should check for the types\n\n    @staticmethod\n    def assert_inputs(\n        rel: \"LogicalPlan\",\n        n: int = 1,\n        context: \"dask_sql.Context\" = None,\n    ) -> list[dd.DataFrame]:\n        \"\"\"\n        LogicalPlan nodes build on top of others.\n        Those are called the \"inputs\" of the LogicalPlan.\n        This function asserts that the given LogicalPlan has exactly as many\n        input tables as expected and returns them already\n        converted into dask dataframes.\n        \"\"\"\n        input_rels = rel.get_inputs()\n\n        assert len(input_rels) == n\n\n        # Late import to avoid a circular dependency\n        from dask_sql.physical.rel.convert import RelConverter\n\n        return [RelConverter.convert(input_rel, context) for input_rel in input_rels]\n\n    @staticmethod\n    def fix_dtype_to_row_type(\n        dc: DataContainer, row_type: \"RelDataType\", join_type: Optional[str] = None\n    ):\n        \"\"\"\n        Fix the dtype of the given data container (or: the df within it)\n        to the data type given as argument.\n        To prevent unneeded conversions, only convert if really needed,\n        e.g. 
if the two types are \"similar\" enough, do not convert.\n        Similarity involves the same general type (int, float, string, etc.)\n        but not necessarily the size (int64 and int32 are compatible)\n        or the nullability.\n        TODO: we should check the nullability of the SQL type\n        \"\"\"\n        df = dc.df\n        cc = dc.column_container\n\n        field_list = row_type.getFieldList()\n        if join_type in (\"leftsemi\", \"leftanti\"):\n            field_list = field_list[: len(cc.columns)]\n\n        field_types = {\n            str(field.getQualifiedName()): field.getType() for field in field_list\n        }\n\n        for field_name, field_type in field_types.items():\n            sql_type = field_type.getSqlType()\n            sql_type_args = tuple()\n\n            if str(sql_type) == \"SqlTypeName.DECIMAL\":\n                sql_type_args = field_type.getDataType().getPrecisionScale()\n\n            expected_type = sql_to_python_type(sql_type, *sql_type_args)\n            df_field_name = cc.get_backend_by_frontend_name(field_name)\n            df = cast_column_type(df, df_field_name, expected_type)\n\n        return DataContainer(df, dc.column_container)\n"
  },
  {
    "path": "dask_sql/physical/rel/convert.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\n\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.utils import LoggableDataFrame, Pluggable\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass RelConverter(Pluggable):\n    \"\"\"\n    Helper to convert from rel to a python expression\n\n    This class stores plugins which can convert from RelNodes to\n    python expression (typically dask dataframes).\n    The stored plugins are assumed to have a class attribute \"class_name\"\n    to control, which java classes they can convert\n    and they are expected to have a convert (instance) method\n    in the form\n\n        def convert(self, rel, context)\n\n    to do the actual conversion.\n    \"\"\"\n\n    @classmethod\n    def add_plugin_class(cls, plugin_class: BaseRelPlugin, replace=True):\n        \"\"\"Convenience function to add a class directly to the plugins\"\"\"\n        logger.debug(f\"Registering REL plugin for {plugin_class.class_name}\")\n        cls.add_plugin(plugin_class.class_name, plugin_class(), replace=replace)\n\n    @classmethod\n    def convert(cls, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> dd.DataFrame:\n        \"\"\"\n        Convert SQL AST tree node(s)\n        into a python expression (a dask dataframe)\n        using the stored plugins and the dictionary of\n        registered dask tables from the context.\n        The SQL AST tree is traversed. The context of the traversal is saved\n        in the Rust logic. 
We need to take that current node and determine\n        what \"type\" of Relational operator it represents to build the execution chain.\n        \"\"\"\n\n        node_type = rel.get_current_node_type()\n\n        try:\n            plugin_instance = cls.get_plugin(node_type)\n        except KeyError:  # pragma: no cover\n            raise NotImplementedError(\n                f\"No relational conversion for node type {node_type} available (yet).\"\n            )\n        logger.debug(\n            f\"Processing REL {rel} using {plugin_instance.__class__.__name__}...\"\n        )\n        df = plugin_instance.convert(rel, context=context)\n        logger.debug(f\"Processed REL {rel} into {LoggableDataFrame(df)}\")\n        return df\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/__init__.py",
    "content": "from .alter import AlterSchemaPlugin, AlterTablePlugin\nfrom .analyze_table import AnalyzeTablePlugin\nfrom .create_catalog_schema import CreateCatalogSchemaPlugin\nfrom .create_experiment import CreateExperimentPlugin\nfrom .create_memory_table import CreateMemoryTablePlugin\nfrom .create_model import CreateModelPlugin\nfrom .create_table import CreateTablePlugin\nfrom .describe_model import DescribeModelPlugin\nfrom .distributeby import DistributeByPlugin\nfrom .drop_model import DropModelPlugin\nfrom .drop_schema import DropSchemaPlugin\nfrom .drop_table import DropTablePlugin\nfrom .export_model import ExportModelPlugin\nfrom .predict_model import PredictModelPlugin\nfrom .show_columns import ShowColumnsPlugin\nfrom .show_models import ShowModelsPlugin\nfrom .show_schemas import ShowSchemasPlugin\nfrom .show_tables import ShowTablesPlugin\nfrom .use_schema import UseSchemaPlugin\n\n__all__ = [\n    AnalyzeTablePlugin,\n    CreateExperimentPlugin,\n    CreateModelPlugin,\n    CreateCatalogSchemaPlugin,\n    CreateMemoryTablePlugin,\n    CreateTablePlugin,\n    DropModelPlugin,\n    DropSchemaPlugin,\n    DropTablePlugin,\n    ExportModelPlugin,\n    PredictModelPlugin,\n    ShowColumnsPlugin,\n    DescribeModelPlugin,\n    ShowModelsPlugin,\n    ShowSchemasPlugin,\n    ShowTablesPlugin,\n    UseSchemaPlugin,\n    AlterSchemaPlugin,\n    AlterTablePlugin,\n    DistributeByPlugin,\n]\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/alter.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nlogger = logging.getLogger(__name__)\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\n\nclass AlterSchemaPlugin(BaseRelPlugin):\n    \"\"\"\n    Alter schema name with new name;\n\n       ALTER SCHEMA <old-schema-name> RENAME TO <new-schema-name>\n\n    Using this SQL is equivalent to just doing\n\n        context.alter_schema(<old-schema-name>,<new-schema-name>)\n\n    but can also be used without writing a single line of code.\n    Nothing is returned.\n    \"\"\"\n\n    class_name = \"AlterSchema\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\"):\n        alter_schema = rel.alter_schema()\n\n        old_schema_name = alter_schema.getOldSchemaName()\n        new_schema_name = alter_schema.getNewSchemaName()\n\n        logger.info(\n            f\"changing schema name from `{old_schema_name}` to `{new_schema_name}`\"\n        )\n        if old_schema_name not in context.schema:\n            raise KeyError(\n                f\"Schema {old_schema_name} was not found, available schemas are - {context.schema.keys()}\"\n            )\n        context.alter_schema(\n            old_schema_name=old_schema_name, new_schema_name=new_schema_name\n        )\n\n\nclass AlterTablePlugin(BaseRelPlugin):\n    \"\"\"\n    Alter table name with new name;\n\n       ALTER TABLE [IF EXISTS] <old-table-name> RENAME TO <new-table-name>\n\n    Using this SQL is equivalent to just doing\n\n        context.alter_table(<old-table-name>,<new-table-name>)\n\n    but can also be used without writing a single line of code.\n    Nothing is returned.\n    \"\"\"\n\n    class_name = \"AlterTable\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\"):\n        alter_table = rel.alter_table()\n\n        old_table_name = alter_table.getOldTableName()\n        new_table_name 
= alter_table.getNewTableName()\n        schema_name = alter_table.getSchemaName() or context.schema_name\n\n        logger.info(\n            f\"changing table name from `{old_table_name}` to `{new_table_name}`\"\n        )\n        if old_table_name not in context.schema[schema_name].tables:\n            if not alter_table.getIfExists():\n                raise KeyError(\n                    f\"Table {old_table_name} was not found, available tables in {schema_name} are \"\n                    f\"- {context.schema[schema_name].tables.keys()}\"\n                )\n            else:\n                return\n\n        context.alter_table(\n            old_table_name=old_table_name,\n            new_table_name=new_table_name,\n            schema_name=schema_name,\n        )\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/analyze_table.py",
    "content": "from typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.mappings import python_to_sql_type\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\n\nclass AnalyzeTablePlugin(BaseRelPlugin):\n    \"\"\"\n    Show information on the table (like mean, max etc.)\n    on all or a subset of the columns..\n    The SQL is:\n\n        ANALYZE TABLE <table> COMPUTE STATISTICS FOR [ALL COLUMNS | COLUMNS a, b, ...]\n\n    The result is also a table, although it is created on the fly.\n\n    Please note: even though the syntax is very similar to e.g.\n    [the spark version](https://spark.apache.org/docs/3.0.0/sql-ref-syntax-aux-analyze-table.html),\n    this call does not help with query optimization (as the spark call would do),\n    as this is currently not implemented in dask-sql.\n    \"\"\"\n\n    class_name = \"AnalyzeTable\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        analyze_table = rel.analyze_table()\n\n        schema_name = analyze_table.getSchemaName() or context.schema_name\n        table_name = analyze_table.getTableName()\n\n        dc = context.schema[schema_name].tables[table_name]\n        columns = analyze_table.getColumns()\n\n        if not columns:\n            columns = dc.column_container.columns\n\n        # Define some useful shortcuts\n        mapping = dc.column_container.get_backend_by_frontend_name\n        df = dc.df\n\n        # Calculate statistics\n        statistics = dd.concat(\n            [\n                df[[mapping(col) for col in columns]].describe(),\n                pd.DataFrame(\n                    {\n                        mapping(col): str(\n                            python_to_sql_type(df[mapping(col)].dtype)\n                        
).lower()\n                        for col in columns\n                    },\n                    index=[\"data_type\"],\n                ),\n                pd.DataFrame(\n                    {mapping(col): col for col in columns}, index=[\"col_name\"]\n                ),\n            ]\n        )\n\n        cc = ColumnContainer(statistics.columns)\n        dc = DataContainer(statistics, cc)\n        return dc\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/create_catalog_schema.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass CreateCatalogSchemaPlugin(BaseRelPlugin):\n    \"\"\"\n    Create a schema with the given name\n    and register it at the context.\n    The SQL call looks like\n\n        CREATE SCHEMA <schema-name>\n\n    Using this SQL is equivalent to just doing\n\n        context.create_schema(<schema-name>)\n\n    but can also be used without writing a single line of code.\n    Nothing is returned.\n    \"\"\"\n\n    class_name = \"CreateCatalogSchema\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\"):\n        create_schema = rel.create_catalog_schema()\n        schema_name = create_schema.getSchemaName()\n\n        if schema_name in context.schema:\n            if create_schema.getIfNotExists():\n                return\n            elif not create_schema.getReplace():\n                raise RuntimeError(\n                    f\"A Schema with the name {schema_name} is already present.\"\n                )\n\n        context.create_schema(schema_name)\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/create_experiment.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.physical.utils.ml_classes import get_cpu_classes, get_gpu_classes\nfrom dask_sql.utils import convert_sql_kwargs, import_class, is_cudf_type\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql.rust import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\ncpu_classes = get_cpu_classes()\ngpu_classes = get_gpu_classes()\n\n\nclass CreateExperimentPlugin(BaseRelPlugin):\n    \"\"\"\n    Creates an  Experiment for hyperparameter tuning or automl like behaviour,\n    i.e evaluates models with different hyperparameters and registers the best performing\n    model in the context with the name same as experiment name,\n    which can be used for prediction\n\n    sql syntax:\n        CREATE EXPERIMENT <name> WITH ( key = value )\n            AS <some select query>\n\n    OPTIONS:\n    * model_class: Class name or full path to the class of the model to train.\n      Any sklearn, cuML, XGBoost, or LightGBM classes can be inferred\n      without the full path. In this case, models trained on cuDF dataframes\n      are automatically mapped to cuML classes, and sklearn models otherwise.\n      We map to cuML-Dask based models when possible and single-GPU cuML models otherwise.\n      Any model class with sklearn interface is valid, but might or\n      might not work well with Dask dataframes.\n      You might need to install necessary packages to use\n      the models.\n    * experiment_class : Class name or full path of the Hyperparameter tuner.\n      Any sklearn or cuML classes can be inferred\n      without the full path. 
In this case, models trained on cuDF dataframes\n      are automatically mapped to cuML classes, and sklearn models otherwise.\n    * tune_parameters:\n      Key-value pairs of hyperparameters to tune, i.e. the search space for\n      the particular model to tune\n    * automl_class : Full path of the class which is sklearn compatible and\n      able to distribute work to dask clusters, currently tested with\n      tpot automl framework.\n      Refer : [Tpot example](https://examples.dask.org/machine-learning/tpot.html)\n    * target_column: Which column from the data to use as target.\n      This parameter is currently required, because tuning and automl\n      behaviour is implemented only for supervised algorithms.\n    * automl_kwargs:\n      Key-value pairs of arguments to be passed to the automl class.\n      Refer : [Using Tpot parameters](https://epistasislab.github.io/tpot/using/)\n    * experiment_kwargs:\n      Use this parameter for passing any keyword arguments to the experiment class\n    * tune_fit_kwargs:\n      Use this parameter for passing any keyword arguments to the experiment.fit() method\n\n      example:\n        for Hyperparameter tuning  : (Train and evaluate the same model with different parameters)\n\n            CREATE EXPERIMENT my_exp WITH(\n            model_class = 'sklearn.ensemble.GradientBoostingClassifier',\n            experiment_class = 'sklearn.model_selection.GridSearchCV',\n            tune_parameters = (n_estimators = ARRAY [16, 32, 2],\n                                learning_rate = ARRAY [0.1,0.01,0.001],\n                               max_depth = ARRAY [3,4,5,10]\n                               ),\n            target_column = 'target'\n            ) AS (\n                    SELECT x, y, x*y > 0 AS target\n                    FROM timeseries\n                    LIMIT 100\n                )\n\n       for automl : (Train different models with different parameters)\n\n            CREATE EXPERIMENT my_exp WITH (\n            
automl_class = 'tpot.TPOTClassifier',\n            automl_kwargs = (population_size = 2 ,\n            generations=2,\n            cv=2,\n            n_jobs=-1,\n            use_dask=True,\n            max_eval_time_mins=1),\n            target_column = 'target'\n            ) AS (\n                SELECT x, y, x*y > 0 AS target\n                FROM timeseries\n                LIMIT 100\n            )\n\n    \"\"\"\n\n    class_name = \"CreateExperiment\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        create_experiment = rel.create_experiment()\n\n        select = create_experiment.getSelectQuery()\n        schema_name = create_experiment.getSchemaName() or context.schema_name\n        experiment_name = create_experiment.getExperimentName()\n        kwargs = convert_sql_kwargs(create_experiment.getSQLWithOptions())\n\n        if experiment_name in context.schema[schema_name].experiments:\n            if create_experiment.getIfNotExists():\n                return\n            elif not create_experiment.getOrReplace():\n                raise RuntimeError(\n                    f\"A experiment with the name {experiment_name} is already present.\"\n                )\n\n        logger.debug(\n            f\"Creating Experiment {experiment_name} from query {select} with options {kwargs}\"\n        )\n        model_class = None\n        automl_class = None\n        experiment_class = None\n        if \"model_class\" in kwargs:\n            model_class = kwargs.pop(\"model_class\")\n            # when model class was provided, must provide experiment_class also for tuning\n            if \"experiment_class\" not in kwargs:\n                raise ValueError(\n                    f\"Parameters must include a 'experiment_class' parameter for tuning {model_class}.\"\n                )\n            experiment_class = kwargs.pop(\"experiment_class\")\n        elif \"automl_class\" in kwargs:\n            automl_class = 
kwargs.pop(\"automl_class\")\n        else:\n            raise ValueError(\n                \"Parameters must include a 'model_class' or 'automl_class' parameter.\"\n            )\n        target_column = kwargs.pop(\"target_column\", \"\")\n        tune_fit_kwargs = kwargs.pop(\"tune_fit_kwargs\", {})\n        parameters = kwargs.pop(\"tune_parameters\", {})\n        experiment_kwargs = kwargs.pop(\"experiment_kwargs\", {})\n        automl_kwargs = kwargs.pop(\"automl_kwargs\", {})\n        logger.info(parameters)\n\n        training_df = context.sql(select)\n        if not target_column:\n            raise ValueError(\n                \"Unsupervised Algorithm cannot be tuned Automatically, \"\n                \"Consider providing 'target_column'\"\n            )\n        non_target_columns = [\n            col for col in training_df.columns if col != target_column\n        ]\n        X = training_df[non_target_columns]\n        y = training_df[target_column]\n\n        if model_class and experiment_class:\n            if is_cudf_type(training_df):\n                model_class = gpu_classes.get(model_class, model_class)\n                experiment_class = gpu_classes.get(experiment_class, experiment_class)\n            else:\n                model_class = cpu_classes.get(model_class, model_class)\n                experiment_class = cpu_classes.get(experiment_class, experiment_class)\n\n            try:\n                ModelClass = import_class(model_class)\n            except ImportError:\n                raise ValueError(\n                    f\"Can not import model {model_class}. Make sure you spelled it correctly and have installed all packages.\"\n                )\n            try:\n                ExperimentClass = import_class(experiment_class)\n            except ImportError:\n                raise ValueError(\n                    f\"Can not import tuner {experiment_class}. 
Make sure you spelled it correctly and have installed all packages.\"\n                )\n\n            from dask_sql.physical.rel.custom.wrappers import ParallelPostFit\n\n            model = ModelClass()\n\n            search = ExperimentClass(model, {**parameters}, **experiment_kwargs)\n            logger.info(tune_fit_kwargs)\n            search.fit(\n                X.to_dask_array(lengths=True),\n                y.to_dask_array(lengths=True),\n                **tune_fit_kwargs,\n            )\n            df = pd.DataFrame(search.cv_results_)\n            df[\"model_class\"] = model_class\n\n            context.register_model(\n                experiment_name,\n                ParallelPostFit(estimator=search.best_estimator_),\n                X.columns,\n                schema_name=schema_name,\n            )\n\n        if automl_class:\n\n            try:\n                AutoMLClass = import_class(automl_class)\n            except ImportError:\n                raise ValueError(\n                    f\"Can not import automl model {automl_class}. 
Make sure you spelled it correctly and have installed all packages.\"\n                )\n\n            from dask_sql.physical.rel.custom.wrappers import ParallelPostFit\n\n            automl = AutoMLClass(**automl_kwargs)\n            # should be avoided if  data doesn't fit in memory\n            automl.fit(X.compute(), y.compute())\n            df = (\n                pd.DataFrame(automl.evaluated_individuals_)\n                .T.reset_index()\n                .rename({\"index\": \"models\"}, axis=1)\n            )\n\n            context.register_model(\n                experiment_name,\n                ParallelPostFit(estimator=automl.fitted_pipeline_),\n                X.columns,\n                schema_name=schema_name,\n            )\n\n        context.register_experiment(\n            experiment_name, experiment_results=df, schema_name=schema_name\n        )\n        cc = ColumnContainer(df.columns)\n        dc = DataContainer(dd.from_pandas(df, npartitions=1), cc)\n        return dc\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/create_memory_table.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass CreateMemoryTablePlugin(BaseRelPlugin):\n    \"\"\"\n    Create a table or view from the given SELECT query\n    and register it at the context.\n    The SQL call looks like\n\n        CREATE TABLE <table-name> AS\n            <some select query>\n\n    It sends the select query through the normal parsing\n    and optimization and conversation before registering it.\n\n    Using this SQL is equivalent to just doing\n\n        df = context.sql(\"<select query>\")\n        context.create_table(<table-name>, df)\n\n    but can also be used without writing a single line of code.\n    Nothing is returned.\n    \"\"\"\n\n    class_name = [\"CreateMemoryTable\", \"CreateView\"]\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        # Rust create_memory_table instance handle\n        create_memory_table = rel.create_memory_table()\n\n        qualified_table_name = create_memory_table.getQualifiedName()\n        *schema_name, table_name = qualified_table_name.split(\".\")\n\n        if len(schema_name) > 1:\n            raise RuntimeError(\n                f\"Expected unqualified or fully qualified table name, got {qualified_table_name}.\"\n            )\n\n        schema_name = context.schema_name if not schema_name else schema_name[0]\n\n        if schema_name not in context.schema:\n            raise RuntimeError(f\"A schema with the name {schema_name} is not present.\")\n        if table_name in context.schema[schema_name].tables:\n            if create_memory_table.getIfNotExists():\n                return\n            elif not create_memory_table.getOrReplace():\n                raise RuntimeError(\n        
            f\"A table with the name {table_name} is already present.\"\n                )\n\n        input_rel = create_memory_table.getInput()\n\n        # TODO: we currently always persist for CREATE TABLE AS and never persist for CREATE VIEW AS;\n        # should this be configured by the user? https://github.com/dask-contrib/dask-sql/issues/269\n        persist = create_memory_table.isTable()\n\n        logger.debug(\n            f\"Creating new table with name {qualified_table_name} and logical plan {input_rel}\"\n        )\n\n        context.create_table(\n            table_name,\n            context._compute_table_from_rel(input_rel),\n            persist=persist,\n            schema_name=schema_name,\n        )\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/create_model.py",
    "content": "import logging\nimport warnings\nfrom typing import TYPE_CHECKING\n\nimport numpy as np\nfrom dask import delayed\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.physical.utils.ml_classes import get_cpu_classes, get_gpu_classes\nfrom dask_sql.utils import convert_sql_kwargs, import_class, is_cudf_type\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql.rust import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\ncpu_classes = get_cpu_classes()\ngpu_classes = get_gpu_classes()\n\n\nclass CreateModelPlugin(BaseRelPlugin):\n    \"\"\"\n    Create and train a model on the data from the given SELECT query\n    and register it at the context.\n    The SQL call looks like\n\n        CREATE MODEL <model-name> WITH ( key = value )\n            AS <some select query>\n\n    It sends the select query through the normal parsing\n    and optimization and conversation and uses the result\n    as the training input.\n\n    The options control, how and which model is trained:\n    * model_class: Class name or full path to the class of the model to train.\n      Any sklearn, cuML, XGBoost, or LightGBM classes can be inferred\n      without the full path. In this case, models trained on cuDF dataframes\n      are automatically mapped to cuML classes, and sklearn models otherwise.\n      We map to cuML-Dask based models when possible and single-GPU cuML models otherwise.\n      Any model class with sklearn interface is valid, but might or\n      might not work well with Dask dataframes.\n      You might need to install necessary packages to use\n      the models.\n    * target_column: Which column from the data to use as target.\n      If not empty, it is removed automatically from\n      the training data. Defaults to an empty string, in which\n      case no target is feed to the model training (e.g. for\n      unsupervised algorithms). 
This means, you typically\n      want to set this parameter.\n    * wrap_predict: Boolean flag, whether to wrap the selected\n      model with a :class:`dask_sql.physical.rel.custom.wrappers.ParallelPostFit`.\n      Defaults to true for sklearn and single GPU cuML models and false otherwise.\n      Typically you set it to true for sklearn models if predicting on big data.\n    * wrap_fit: Boolean flag, whether to wrap the selected\n      model with a :class:`dask_sql.physical.rel.custom.wrappers.Incremental`.\n      Defaults to true for sklearn and single GPU cuML models and false otherwise.\n      Typically you set it to true for sklearn models if training on big data.\n    * fit_kwargs: keyword arguments sent to the call to fit().\n\n    All other arguments are passed to the constructor of the\n    model class.\n\n    Using this SQL is roughly equivalent to doing\n\n        df = context.sql(\"<select query>\")\n        X = df[everything except target_column]\n        y = df[target_column]\n        model = ModelClass(**kwargs)\n\n        model = model.fit(X, y, **fit_kwargs)\n        context.register_model(<model-name>, model)\n\n    but can also be used without writing a single line of code.\n    Nothing is returned.\n\n    Examples:\n\n        CREATE MODEL my_model WITH (\n            model_class = 'xgboost.XGBClassifier',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, target\n            FROM \"data\"\n        )\n\n    Notes:\n\n        This SQL call is not a 1:1 replacement for a normal\n        python training and can not fulfill all use-cases\n        or requirements!\n\n        If you are dealing with large amounts of data,\n        you might run into problems while model training and/or\n        prediction, depending if your model can cope with\n        dask dataframes.\n\n        * if you are training on relatively small amounts\n          of data but predicting on large data samples,\n          you might want to set 
`wrap_predict` to True.\n          With this option, model inference will be\n          parallelized/distributed.\n        * If you are training on large amounts of data,\n          you can try setting wrap_fit to True. This will\n          do the same on the training step, but works only on\n          those models that have a `partial_fit` method.\n    \"\"\"\n\n    class_name = \"CreateModel\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        create_model = rel.create_model()\n\n        select = create_model.getSelectQuery()\n        schema_name = create_model.getSchemaName() or context.schema_name\n        model_name = create_model.getModelName()\n        kwargs = convert_sql_kwargs(create_model.getSQLWithOptions())\n\n        if model_name in context.schema[schema_name].models:\n            if create_model.getIfNotExists():\n                return\n            elif not create_model.getOrReplace():\n                raise RuntimeError(\n                    f\"A model with the name {model_name} is already present.\"\n                )\n\n        logger.debug(\n            f\"Creating model {model_name} from query {select} with options {kwargs}\"\n        )\n\n        try:\n            model_class = kwargs.pop(\"model_class\")\n        except KeyError:\n            raise ValueError(\"Parameters must include a 'model_class' parameter.\")\n\n        target_column = kwargs.pop(\"target_column\", \"\")\n        wrap_predict = kwargs.pop(\"wrap_predict\", None)\n        wrap_fit = kwargs.pop(\"wrap_fit\", None)\n        fit_kwargs = kwargs.pop(\"fit_kwargs\", {})\n\n        if wrap_predict is False and \"dask\" not in model_class.lower():\n            warnings.warn(\n                f\"Consider using wrap_predict=True for non-Dask model {model_class}\",\n                RuntimeWarning,\n            )\n\n        training_df = context.sql(select)\n\n        if is_cudf_type(training_df):\n            model_class = 
gpu_classes.get(model_class, model_class)\n        else:\n            model_class = cpu_classes.get(model_class, model_class)\n\n        try:\n            ModelClass = import_class(model_class)\n        except ImportError:\n            raise ImportError(\n                f\"Failed to import model {model_class}. Make sure it is spelled correctly and the relevant packages are installed.\"\n            )\n\n        model = ModelClass(**kwargs)\n\n        if wrap_predict is None:\n            if (\n                \"sklearn\" in model_class\n                or (\"cuml\" in model_class and \"cuml.dask\" not in model_class)\n                or (\"xgboost\" in model_class and \"xgboost.dask\" not in model_class)\n            ):\n                wrap_predict = True\n            else:\n                wrap_predict = False\n        if wrap_fit is None:\n            if (\n                \"sklearn\" in model_class\n                or (\"cuml\" in model_class and \"cuml.dask\" not in model_class)\n                or (\"xgboost\" in model_class and \"xgboost.dask\" not in model_class)\n            ) and hasattr(model, \"partial_fit\"):\n                wrap_fit = True\n            else:\n                wrap_fit = False\n\n        if target_column:\n            non_target_columns = [\n                col for col in training_df.columns if col != target_column\n            ]\n            X = training_df[non_target_columns]\n            y = training_df[target_column]\n        else:\n            X = training_df\n            y = None\n\n        if wrap_fit:\n            from dask_sql.physical.rel.custom.wrappers import Incremental\n\n            model = Incremental(estimator=model)\n\n        if wrap_predict:\n            from dask_sql.physical.rel.custom.wrappers import ParallelPostFit\n\n            # When `wrap_predict` is set to True we train on single partition frames\n            # because this is only useful for non dask distributed models\n            # Training via delayed 
fit ensures that we dont have to transfer\n            # data back to the client for training\n\n            X_d = X.repartition(npartitions=1).to_delayed()\n            if y is not None:\n                y_d = y.repartition(npartitions=1).to_delayed()\n            else:\n                y_d = [None]\n\n            delayed_model = [delayed(model.fit)(x_p, y_p) for x_p, y_p in zip(X_d, y_d)]\n            model = delayed_model[0].compute()\n            if \"sklearn\" in model_class:\n                output_meta = np.array([])\n                model = ParallelPostFit(\n                    estimator=model,\n                    predict_meta=output_meta,\n                    predict_proba_meta=output_meta,\n                    transform_meta=output_meta,\n                )\n            else:\n                model = ParallelPostFit(estimator=model)\n\n        else:\n            model.fit(X, y, **fit_kwargs)\n        context.register_model(model_name, model, X.columns, schema_name=schema_name)\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/create_table.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.utils import convert_sql_kwargs\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass CreateTablePlugin(BaseRelPlugin):\n    \"\"\"\n    Create a table with given parameters from already existing data\n    and register it at the context.\n    The SQL call looks like\n\n        CREATE TABLE <table-name> WITH (\n            parameter = value,\n            ...\n        )\n\n    It uses calls to \"dask.dataframe.read_<format>\"\n    where format is given by the \"format\" parameter (defaults to CSV).\n    The only mandatory parameter is the \"location\" parameter.\n\n    Using this SQL is equivalent to just doing\n\n        df = dd.read_<format>(location, **kwargs)\n        context.register_dask_dataframe(df, <table-name>)\n\n    but can also be used without writing a single line of code.\n    Nothing is returned.\n    \"\"\"\n\n    class_name = \"CreateTable\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        create_table = rel.create_table()\n\n        schema_name = create_table.getSchemaName() or context.schema_name\n        table_name = create_table.getTableName()\n\n        if table_name in context.schema[schema_name].tables:\n            if create_table.getIfNotExists():\n                return\n            elif not create_table.getOrReplace():\n                raise RuntimeError(\n                    f\"A table with the name {table_name} is already present.\"\n                )\n\n        kwargs = convert_sql_kwargs(create_table.getSQLWithOptions())\n\n        logger.debug(\n            f\"Creating new table with name {table_name} and parameters {kwargs}\"\n        )\n\n        format = kwargs.pop(\"format\", None)\n        if format: 
 # pragma: no cover\n            format = format.lower()\n        persist = kwargs.pop(\"persist\", False)\n\n        try:\n            location = kwargs.pop(\"location\")\n        except KeyError:\n            raise AttributeError(\"Parameters must include a 'location' parameter.\")\n\n        gpu = kwargs.pop(\"gpu\", False)\n        context.create_table(\n            table_name,\n            location,\n            format=format,\n            persist=persist,\n            schema_name=schema_name,\n            gpu=gpu,\n            **kwargs,\n        )\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/describe_model.py",
    "content": "from typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\n\nclass DescribeModelPlugin(BaseRelPlugin):\n    \"\"\"\n    Show all Params used to train a given model along with the columns\n    used for training.\n    The SQL is:\n\n        DESCRIBE MODEL <model_name>\n\n    The result is also a table, although it is created on the fly.\n    \"\"\"\n\n    class_name = \"DescribeModel\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        describe_model = rel.describe_model()\n\n        schema_name = describe_model.getSchemaName() or context.schema_name\n        model_name = describe_model.getModelName()\n\n        if model_name not in context.schema[schema_name].models:\n            raise RuntimeError(f\"A model with the name {model_name} is not present.\")\n\n        model, training_columns = context.schema[schema_name].models[model_name]\n\n        model_params = model.get_params()\n        model_params[\"training_columns\"] = training_columns.tolist()\n\n        df = pd.DataFrame.from_dict(model_params, orient=\"index\", columns=[\"Params\"])\n        cc = ColumnContainer(df.columns)\n        dc = DataContainer(dd.from_pandas(df, npartitions=1), cc)\n        return dc\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/distributeby.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.utils import LoggableDataFrame\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass DistributeByPlugin(BaseRelPlugin):\n    \"\"\"\n    Distribute the target based on the specified sql identifier from a SELECT query.\n    The SQL is:\n\n        SELECT age, name FROM person DISTRIBUTE BY age\n    \"\"\"\n\n    # DataFusion provides the phrase `Repartition` in the LogicalPlan instead of `Distribute By`, it is the same thing\n    class_name = \"Repartition\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        distribute = rel.repartition_by()\n        select = distribute.getSelectQuery()\n        distribute_list = distribute.getDistributionColumns()\n\n        df = context.sql(select)\n        logger.debug(f\"Extracted sub-dataframe as {LoggableDataFrame(df)}\")\n\n        logger.debug(f\"Will now shuffle according to {distribute_list}\")\n\n        # Perform the distribute by operation via a Dask shuffle\n        df = df.shuffle(distribute_list)\n\n        cc = ColumnContainer(df.columns)\n        dc = DataContainer(df, cc)\n\n        return dc\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/drop_model.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql.rust import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass DropModelPlugin(BaseRelPlugin):\n    \"\"\"\n    Drop a model with given name.\n    The SQL call looks like\n\n        DROP MODEL <table-name>\n    \"\"\"\n\n    class_name = \"DropModel\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        drop_model = rel.drop_model()\n\n        schema_name = drop_model.getSchemaName() or context.schema_name\n        model_name = drop_model.getModelName()\n\n        if model_name not in context.schema[schema_name].models:\n            if not drop_model.getIfExists():\n                raise RuntimeError(\n                    f\"A model with the name {model_name} is not present.\"\n                )\n            else:\n                return\n\n        del context.schema[schema_name].models[model_name]\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/drop_schema.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass DropSchemaPlugin(BaseRelPlugin):\n    \"\"\"\n    Drop a schema with given name.\n    The SQL call looks like\n\n        DROP SCHEMA <schema-name>\n    \"\"\"\n\n    class_name = \"DropSchema\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\"):\n        drop_schema = rel.drop_schema()\n        schema_name = drop_schema.getSchemaName()\n\n        if schema_name not in context.schema:\n            if not drop_schema.getIfExists():\n                raise RuntimeError(\n                    f\"A SCHEMA with the name {schema_name} is not present.\"\n                )\n            else:\n                return\n\n        context.drop_schema(schema_name)\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/drop_table.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql.rust import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass DropTablePlugin(BaseRelPlugin):\n    \"\"\"\n    Drop a table with given name.\n    The SQL call looks like\n\n        DROP TABLE <table-name>\n    \"\"\"\n\n    class_name = \"DropTable\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        # Rust create_memory_table instance handle\n        drop_table = rel.drop_table()\n\n        qualified_table_name = drop_table.getQualifiedName()\n        *schema_name, table_name = qualified_table_name.split(\".\")\n\n        if len(schema_name) > 1:\n            raise RuntimeError(\n                f\"Expected unqualified or fully qualified table name, got {qualified_table_name}.\"\n            )\n\n        schema_name = context.schema_name if not schema_name else schema_name[0]\n\n        if (\n            schema_name not in context.schema\n            or table_name not in context.schema[schema_name].tables\n        ):\n            if not drop_table.getIfExists():\n                raise RuntimeError(\n                    f\"A table with the name {qualified_table_name} is not present.\"\n                )\n            else:\n                return\n\n        context.drop_table(table_name, schema_name=schema_name)\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/export_model.py",
    "content": "import logging\nimport pickle\nfrom typing import TYPE_CHECKING\n\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.utils import convert_sql_kwargs\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass ExportModelPlugin(BaseRelPlugin):\n    \"\"\"\n     Export a trained model into a file using one of the supported model serialization libraries.\n\n    Sql syntax:\n        EXPORT MODEL <model_name> WTIH (\n            format = \"pickle\",\n            location = \"model.pkl\"\n        )\n    1. Most of the machine learning model framework support pickle as a serialization format\n        for example:\n            sklearn\n            Pytorch\n    2. To export a universal (framework agnostic) model, use the mlflow (https://mlflow.org/) format\n        - mlflow is a framework, which supports different flavors of model serialization, implemented\n        for different ML libraries like xgboost,catboost,lightgbm etc.\n        - A mlflow model is a self-contained artifact, which contains everything you need for\n        loading the model - without import errors\n        - To reproduce the environment, conda.yaml files are produced while saving the\n        model and stored as part of the mlflow model\n\n        NOTE:\n        - Since dask-sql expects fit-predict style model (i.e sklearn compatible model),\n            Only sklearn flavoured/sklearn subclassed models are supported as a part of mlflow serialization.\n            i.e only mlflow sklearn flavour was used for all the sklearn compatible models.\n            for example :\n                instead of using xgb.core.Booster consider using xgboost.XGBClassifier\n                since later is sklearn compatible\n    \"\"\"\n\n    class_name = \"ExportModel\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\"):\n        export_model = rel.export_model()\n\n        
schema_name = export_model.getSchemaName() or context.schema_name\n        model_name = export_model.getModelName()\n        kwargs = convert_sql_kwargs(export_model.getSQLWithOptions())\n\n        format = kwargs.pop(\"format\", \"pickle\").lower().strip()\n        location = kwargs.pop(\"location\", \"tmp.pkl\").strip()\n        try:\n            model, training_columns = context.schema[schema_name].models[model_name]\n        except KeyError:\n            raise RuntimeError(f\"A model with the name {model_name} is not present.\")\n\n        logger.info(\n            f\"Using {format} as the model serde; the model will be exported to {location}\"\n        )\n        if format in [\"pickle\", \"pkl\"]:\n            with open(location, \"wb\") as pkl_file:\n                pickle.dump(model, pkl_file, **kwargs)\n        elif format == \"joblib\":\n            import joblib\n\n            joblib.dump(model, location, **kwargs)\n        elif format == \"mlflow\":\n            try:\n                import mlflow\n            except ImportError:  # pragma: no cover\n                raise ImportError(\n                    \"For export in the mlflow format, you need to have mlflow installed\"\n                )\n            try:\n                import sklearn\n            except ImportError:  # pragma: no cover\n                sklearn = None\n            if sklearn is not None and isinstance(model, sklearn.base.BaseEstimator):\n                mlflow.sklearn.save_model(model, location, **kwargs)\n            else:\n                raise NotImplementedError(\n                    \"dask-sql supports only sklearn compatible model i.e fit-predict style model\"\n                )\n        elif format == \"onnx\":\n            \"\"\"\n            Needs the columns and their data types to convert\n            any model format into ONNX format, and every framework\n            needs its respective ONNX converter installed\n            \"\"\"\n            # TODO: Add support for 
Exporting model into ONNX format\n            raise NotImplementedError(\"ONNX format currently not supported\")\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/metrics.py",
    "content": "# Copyright 2017, Dask developers\n# Dask-ML project - https://github.com/dask/dask-ml\nfrom typing import Optional, TypeVar\n\nimport dask\nimport dask.array as da\nimport numpy as np\nimport sklearn.metrics\nimport sklearn.utils.multiclass\nfrom dask.array import Array\nfrom dask.utils import derived_from\n\nArrayLike = TypeVar(\"ArrayLike\", Array, np.ndarray)\n\n\ndef accuracy_score(\n    y_true: ArrayLike,\n    y_pred: ArrayLike,\n    normalize: bool = True,\n    sample_weight: Optional[ArrayLike] = None,\n    compute: bool = True,\n) -> ArrayLike:\n    \"\"\"Accuracy classification score.\n    In multilabel classification, this function computes subset accuracy:\n    the set of labels predicted for a sample must *exactly* match the\n    corresponding set of labels in y_true.\n    Read more in the :ref:`User Guide <accuracy_score>`.\n    Parameters\n    ----------\n    y_true : 1d array-like, or label indicator array\n        Ground truth (correct) labels.\n    y_pred : 1d array-like, or label indicator array\n        Predicted labels, as returned by a classifier.\n    normalize : bool, optional (default=True)\n        If ``False``, return the number of correctly classified samples.\n        Otherwise, return the fraction of correctly classified samples.\n    sample_weight : 1d array-like, optional\n        Sample weights.\n        .. 
versionadded:: 0.7.3\n    Returns\n    -------\n    score : scalar dask Array\n        If ``normalize == True``, return the correctly classified samples\n        (float), else it returns the number of correctly classified samples\n        (int).\n        The best performance is 1 with ``normalize == True`` and the number\n        of samples with ``normalize == False``.\n    Notes\n    -----\n    In binary and multiclass classification, this function is equal\n    to the ``jaccard_similarity_score`` function.\n\n    \"\"\"\n\n    if y_true.ndim > 1:\n        differing_labels = ((y_true - y_pred) == 0).all(1)\n        score = differing_labels != 0\n    else:\n        score = y_true == y_pred\n\n    if normalize:\n        score = da.average(score, weights=sample_weight)\n    elif sample_weight is not None:\n        score = da.dot(score, sample_weight)\n    else:\n        score = score.sum()\n\n    if compute:\n        score = score.compute()\n    return score\n\n\ndef _log_loss_inner(\n    x: ArrayLike, y: ArrayLike, sample_weight: Optional[ArrayLike], **kwargs\n):\n    # da.map_blocks wasn't able to concatenate together the results\n    # when we reduce down to a scalar per block. 
So we make an\n    # array with 1 element.\n    if sample_weight is not None:\n        sample_weight = sample_weight.ravel()\n    return np.array(\n        [sklearn.metrics.log_loss(x, y, sample_weight=sample_weight, **kwargs)]\n    )\n\n\ndef log_loss(\n    y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None, labels=None\n):\n    if not (dask.is_dask_collection(y_true) and dask.is_dask_collection(y_pred)):\n        return sklearn.metrics.log_loss(\n            y_true,\n            y_pred,\n            eps=eps,\n            normalize=normalize,\n            sample_weight=sample_weight,\n            labels=labels,\n        )\n\n    if y_pred.ndim > 1 and y_true.ndim == 1:\n        y_true = y_true.reshape(-1, 1)\n        drop_axis: Optional[int] = 1\n        if sample_weight is not None:\n            sample_weight = sample_weight.reshape(-1, 1)\n    else:\n        drop_axis = None\n\n    result = da.map_blocks(\n        _log_loss_inner,\n        y_true,\n        y_pred,\n        sample_weight,\n        chunks=(1,),\n        drop_axis=drop_axis,\n        dtype=\"f8\",\n        eps=eps,\n        normalize=normalize,\n        labels=labels,\n    )\n    if normalize and sample_weight is not None:\n        sample_weight = sample_weight.ravel()\n        block_weights = sample_weight.map_blocks(np.sum, chunks=(1,), keepdims=True)\n        return da.average(result, 0, weights=block_weights)\n    elif normalize:\n        return result.mean()\n    else:\n        return result.sum()\n\n\ndef _check_sample_weight(sample_weight: Optional[ArrayLike]):\n    if sample_weight is not None:\n        raise ValueError(\"'sample_weight' is not supported.\")\n\n\n@derived_from(sklearn.metrics)\ndef mean_squared_error(\n    y_true: ArrayLike,\n    y_pred: ArrayLike,\n    sample_weight: Optional[ArrayLike] = None,\n    multioutput: Optional[str] = \"uniform_average\",\n    squared: bool = True,\n    compute: bool = True,\n) -> ArrayLike:\n    _check_sample_weight(sample_weight)\n   
 output_errors = ((y_pred - y_true) ** 2).mean(axis=0)\n\n    if isinstance(multioutput, str) or multioutput is None:\n        if multioutput == \"raw_values\":\n            if compute:\n                return output_errors.compute()\n            else:\n                return output_errors\n    else:\n        raise ValueError(\"Weighted 'multioutput' not supported.\")\n    result = output_errors.mean()\n    if not squared:\n        result = da.sqrt(result)\n    if compute:\n        result = result.compute()\n    return result\n\n\ndef _check_reg_targets(\n    y_true: ArrayLike, y_pred: ArrayLike, multioutput: Optional[str]\n):\n    if multioutput is not None and multioutput != \"uniform_average\":\n        raise NotImplementedError(\"'multioutput' must be 'uniform_average'\")\n\n    if y_true.ndim == 1:\n        y_true = y_true.reshape((-1, 1))\n    if y_pred.ndim == 1:\n        y_pred = y_pred.reshape((-1, 1))\n\n    # TODO: y_type, multioutput\n    return None, y_true, y_pred, multioutput\n\n\n@derived_from(sklearn.metrics)\ndef r2_score(\n    y_true: ArrayLike,\n    y_pred: ArrayLike,\n    sample_weight: Optional[ArrayLike] = None,\n    multioutput: Optional[str] = \"uniform_average\",\n    compute: bool = True,\n) -> ArrayLike:\n    _check_sample_weight(sample_weight)\n    _, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, multioutput)\n    weight = 1.0\n\n    numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype=\"f8\")\n    denominator = (weight * (y_true - y_true.mean(axis=0)) ** 2).sum(axis=0, dtype=\"f8\")\n\n    nonzero_denominator = denominator != 0\n    nonzero_numerator = numerator != 0\n    valid_score = nonzero_denominator & nonzero_numerator\n    output_chunks = getattr(y_true, \"chunks\", [None, None])[1]\n    output_scores = da.ones([y_true.shape[1]], chunks=output_chunks)\n    with np.errstate(all=\"ignore\"):\n        output_scores[valid_score] = 1 - (\n            numerator[valid_score] / denominator[valid_score]\n        )\n   
     output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0\n\n    result = output_scores.mean(axis=0)\n    if compute:\n        result = result.compute()\n    return result\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/predict_model.py",
    "content": "import logging\nimport uuid\nfrom typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass PredictModelPlugin(BaseRelPlugin):\n    \"\"\"\n    Predict the target using the given model and dataframe from the SELECT query.\n\n    The SQL call looks like\n\n        SELECT <cols> FROM PREDICT (MODEL <model-name>, <some select query>)\n\n    The return value is the input dataframe with an additional column named\n    \"target\", which contains the predicted values.\n    The model needs to be registered at the context before using it in this function,\n    either by calling :ref:`register_model` explicitly or by training\n    a model using the `CREATE MODEL` SQL statement.\n\n    A model can be anything which has a `predict` function.\n    Please note however, that it will need to act on Dask dataframes. If you\n    are using a model not optimized for this, it might be that you run out of memory if\n    your data is larger than the RAM of a single machine.\n    To prevent this, have a look into the dask_sql.physical.rel.custom.wrappers.ParallelPostFit\n    meta-estimator. 
If you are using a model trained with `CREATE MODEL`\n    and the `wrap_predict` flag, this is done automatically.\n\n    Using this SQL is roughly equivalent to doing\n\n        df = context.sql(\"<select query>\")\n        model = get the model from the context\n\n        target = model.predict(df)\n        return df.assign(target=target)\n\n    but can also be used without writing a single line of code.\n    \"\"\"\n\n    class_name = \"PredictModel\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        predict_model = rel.predict_model()\n\n        sql_select = predict_model.getSelect()\n        schema_name = predict_model.getSchemaName() or context.schema_name\n        model_name = predict_model.getModelName()\n\n        model, training_columns = context.schema[schema_name].models[model_name]\n        df = context.sql(sql_select)\n        try:\n            prediction = model.predict(df[training_columns])\n            predicted_df = df.assign(target=prediction)\n        except TypeError:\n            df = df.set_index(df.columns[0], drop=False)\n            prediction = model.predict(df[training_columns])\n            # Convert numpy.ndarray to Dask Series\n            prediction = dd.from_pandas(\n                pd.Series(prediction, index=df.index),\n                npartitions=df.npartitions,\n            )\n            predicted_df = df.assign(target=prediction)\n            # Need to drop first column to reset index\n            # because the first column is equal to the index\n            predicted_df = predicted_df.drop(columns=[df.columns[0]]).reset_index()\n\n        # Create a temporary context, which includes the\n        # new \"table\" so that we can use the normal\n        # SQL-to-dask-code machinery\n        while True:\n            # Make sure to choose a non-used name\n            temporary_table = str(uuid.uuid4())\n            if temporary_table not in context.schema[schema_name].tables:\n 
               break\n            else:  # pragma: no cover\n                continue\n\n        context.create_table(temporary_table, predicted_df)\n\n        cc = ColumnContainer(predicted_df.columns)\n        dc = DataContainer(predicted_df, cc)\n\n        return dc\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/show_columns.py",
    "content": "from typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.mappings import python_to_sql_type\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\n\nclass ShowColumnsPlugin(BaseRelPlugin):\n    \"\"\"\n    Show all columns (and their types) for a given table.\n    The SQL is:\n\n        SHOW COLUMNS FROM <table>\n\n    The result is also a table, although it is created on the fly.\n    \"\"\"\n\n    class_name = \"ShowColumns\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        show_columns = rel.show_columns()\n\n        schema_name = show_columns.getSchemaName() or context.schema_name\n        table_name = show_columns.getTableName()\n\n        dc = context.schema[schema_name].tables[table_name]\n\n        cols = dc.column_container.columns\n        dtypes = list(\n            map(\n                lambda x: str(python_to_sql_type(x)).lower(),\n                dc.df.dtypes,\n            )\n        )\n        df = pd.DataFrame(\n            {\n                \"Column\": cols,\n                \"Type\": dtypes,\n                \"Extra\": [\"\"] * len(cols),\n                \"Comment\": [\"\"] * len(cols),\n            }\n        )\n\n        cc = ColumnContainer(df.columns)\n        dc = DataContainer(dd.from_pandas(df, npartitions=1), cc)\n        return dc\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/show_models.py",
    "content": "from typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\n\nclass ShowModelsPlugin(BaseRelPlugin):\n    \"\"\"\n    Show all MODELS currently registered/trained.\n    The SQL is:\n\n        SHOW MODELS\n\n    The result is also a table, although it is created on the fly.\n    \"\"\"\n\n    class_name = \"ShowModels\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        schema_name = rel.show_models().getSchemaName() or context.schema_name\n\n        df = pd.DataFrame({\"Models\": list(context.schema[schema_name].models.keys())})\n\n        cc = ColumnContainer(df.columns)\n        dc = DataContainer(dd.from_pandas(df, npartitions=1), cc)\n        return dc\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/show_schemas.py",
    "content": "from typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\n\nclass ShowSchemasPlugin(BaseRelPlugin):\n    \"\"\"\n    Show all schemas.\n    The SQL is:\n\n        SHOW SCHEMAS [FROM <catalog-name>] [LIKE <>]\n\n    The result is also a table, although it is created on the fly.\n    \"\"\"\n\n    class_name = \"ShowSchemas\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        show_schemas = rel.show_schemas()\n\n        # \"information_schema\" is a schema which is found in every presto database\n        schemas = list(context.schema.keys())\n        schemas.append(\"information_schema\")\n        df = pd.DataFrame({\"Schema\": schemas})\n\n        # currently catalogs other than the default `dask_sql` are not supported\n        catalog_name = show_schemas.getCatalogName() or context.catalog_name\n        if catalog_name != context.catalog_name:\n            raise RuntimeError(\n                f\"A catalog with the name {catalog_name} is not present.\"\n            )\n\n        # filter by LIKE value\n        like = str(show_schemas.getLike()).strip(\"'\")\n        if like and like != \"None\":\n            df = df[df.Schema == like]\n\n        cc = ColumnContainer(df.columns)\n        dc = DataContainer(dd.from_pandas(df, npartitions=1), cc)\n        return dc\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/show_tables.py",
    "content": "from typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\n\nclass ShowTablesPlugin(BaseRelPlugin):\n    \"\"\"\n    Show all tables currently defined for a given schema.\n    The SQL is:\n\n        SHOW TABLES FROM [<catalog>.]<schema>\n\n    Please note that dask-sql currently\n    only allows for a single schema (called \"schema\").\n\n    The result is also a table, although it is created on the fly.\n    \"\"\"\n\n    class_name = \"ShowTables\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        show_tables = rel.show_tables()\n\n        # currently catalogs other than the default `dask_sql` are not supported\n        catalog_name = show_tables.getCatalogName() or context.catalog_name\n        if catalog_name != context.catalog_name:\n            raise RuntimeError(\n                f\"A catalog with the name {catalog_name} is not present.\"\n            )\n\n        schema_name = show_tables.getSchemaName() or context.schema_name\n\n        if schema_name not in context.schema:\n            raise AttributeError(f\"Schema {schema_name} is not defined.\")\n\n        df = pd.DataFrame({\"Table\": list(context.schema[schema_name].tables.keys())})\n\n        cc = ColumnContainer(df.columns)\n        dc = DataContainer(dd.from_pandas(df, npartitions=1), cc)\n        return dc\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/use_schema.py",
    "content": "from typing import TYPE_CHECKING\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\n\nclass UseSchemaPlugin(BaseRelPlugin):\n    \"\"\"\n    Show all MODELS currently registered/trained.\n    The SQL is:\n\n        SHOW MODELS\n\n    The result is also a table, although it is created on the fly.\n    \"\"\"\n\n    class_name = \"UseSchema\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        schema_name = rel.use_schema().getSchemaName()\n\n        if schema_name in context.schema:\n            context.schema_name = schema_name\n            # set the schema on the underlying DaskSQLContext as well\n            context.context.use_schema(schema_name)\n        else:\n            raise RuntimeError(f\"Schema {schema_name} not available\")\n"
  },
  {
    "path": "dask_sql/physical/rel/custom/wrappers.py",
    "content": "# Copyright 2017, Dask developers\n# Dask-ML project - https://github.com/dask/dask-ml\n\"\"\"Meta-estimators for parallelizing estimators using the scikit-learn API.\"\"\"\nimport logging\nimport warnings\nfrom typing import Any, Callable, Union\n\nimport dask.array as da\nimport dask.dataframe as dd\nimport dask.delayed\nimport numpy as np\nimport sklearn.base\nimport sklearn.metrics\nfrom dask.delayed import Delayed\nfrom dask.highlevelgraph import HighLevelGraph\nfrom sklearn.metrics import check_scoring as sklearn_check_scoring\nfrom sklearn.metrics import make_scorer\nfrom sklearn.utils.validation import check_is_fitted\n\ntry:\n    import sklearn.base\n    import sklearn.metrics\nexcept ImportError:  # pragma: no cover\n    raise ImportError(\"sklearn must be installed\")\n\nfrom dask_sql.physical.rel.custom.metrics import (\n    accuracy_score,\n    log_loss,\n    mean_squared_error,\n    r2_score,\n)\n\nlogger = logging.getLogger(__name__)\n\n\n# Scorers\naccuracy_scorer: tuple[Any, Any] = (accuracy_score, {})\nneg_mean_squared_error_scorer = (mean_squared_error, dict(greater_is_better=False))\nr2_scorer: tuple[Any, Any] = (r2_score, {})\nneg_log_loss_scorer = (log_loss, dict(greater_is_better=False, needs_proba=True))\n\n\nSCORERS = dict(\n    accuracy=accuracy_scorer,\n    neg_mean_squared_error=neg_mean_squared_error_scorer,\n    r2=r2_scorer,\n    neg_log_loss=neg_log_loss_scorer,\n)\n\n\nclass ParallelPostFit(sklearn.base.BaseEstimator, sklearn.base.MetaEstimatorMixin):\n    \"\"\"Meta-estimator for parallel predict and transform.\n\n    Parameters\n    ----------\n    estimator : Estimator\n        The underlying estimator that is fit.\n\n    scoring : string or callable, optional\n        A single string (see :ref:`scoring_parameter`) or a callable\n        (see :ref:`scoring`) to evaluate the predictions on the test set.\n\n        For evaluating multiple metrics, either give a list of (unique)\n        strings or a dict with names 
as keys and callables as values.\n\n        NOTE that when using custom scorers, each scorer should return a\n        single value. Metric functions returning a list/array of values\n        can be wrapped into multiple scorers that return one value each.\n\n        See :ref:`multimetric_grid_search` for an example.\n\n        .. warning::\n\n           If None, the estimator's default scorer (if available) is used.\n           Most scikit-learn estimators will convert large Dask arrays to\n           a single NumPy array, which may exhaust the memory of your worker.\n           You probably want to always specify `scoring`.\n\n    predict_meta: pd.Series, pd.DataFrame, np.array default: None (infer)\n        An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output\n        type of the estimator's ``predict`` call.\n        This meta is necessary for some estimators to work with\n        ``dask.dataframe`` and ``dask.array``\n\n    predict_proba_meta: pd.Series, pd.DataFrame, np.array default: None (infer)\n        An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output\n        type of the estimator's ``predict_proba`` call.\n        This meta is necessary for some estimators to work with\n        ``dask.dataframe`` and ``dask.array``\n\n    transform_meta: pd.Series, pd.DataFrame, np.array default: None (infer)\n        An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output\n        type of the estimator's ``transform`` call.\n        This meta is necessary for some estimators to work with\n        ``dask.dataframe`` and ``dask.array``\n\n    \"\"\"\n\n    class_name = \"ParallelPostFit\"\n\n    def __init__(\n        self,\n        estimator=None,\n        scoring=None,\n        predict_meta=None,\n        predict_proba_meta=None,\n        transform_meta=None,\n    ):\n        self.estimator = estimator\n        self.scoring = scoring\n        self.predict_meta = predict_meta\n        
self.predict_proba_meta = predict_proba_meta\n        self.transform_meta = transform_meta\n\n    def _check_array(self, X):\n        \"\"\"Validate an array for post-fit tasks.\n\n        Parameters\n        ----------\n        X : Union[Array, DataFrame]\n\n        Returns\n        -------\n        same type as 'X'\n\n        Notes\n        -----\n        The following checks are applied.\n\n        - Ensure that the array is blocked only along the samples.\n        \"\"\"\n        if isinstance(X, da.Array):\n            if X.ndim == 2 and X.numblocks[1] > 1:\n                logger.debug(\"auto-rechunking 'X'\")\n                if not np.isnan(X.chunks[0]).any():\n                    X = X.rechunk({0: \"auto\", 1: -1})\n                else:\n                    X = X.rechunk({1: -1})\n        return X\n\n    @property\n    def _postfit_estimator(self):\n        # The estimator instance to use for postfit tasks like score\n        return self.estimator\n\n    def fit(self, X, y=None, **kwargs):\n        \"\"\"Fit the underlying estimator.\n\n        Parameters\n        ----------\n        X, y : array-like\n        **kwargs\n            Additional fit-kwargs for the underlying estimator.\n\n        Returns\n        -------\n        self : object\n        \"\"\"\n        logger.info(\"Starting fit\")\n        result = self.estimator.fit(X, y, **kwargs)\n\n        # Copy over learned attributes\n        copy_learned_attributes(result, self)\n        copy_learned_attributes(result, self.estimator)\n        return self\n\n    def partial_fit(self, X, y=None, **kwargs):\n        logger.info(\"Starting partial_fit\")\n        result = self.estimator.partial_fit(X, y, **kwargs)\n\n        # Copy over learned attributes\n        copy_learned_attributes(result, self)\n        copy_learned_attributes(result, self.estimator)\n        return self\n\n    def transform(self, X):\n        \"\"\"Transform block or partition-wise for dask inputs.\n\n        For dask inputs, a 
dask array or dataframe is returned. For other\n        inputs (NumPy array, pandas dataframe, scipy sparse matrix), the\n        regular return value is returned.\n\n        If the underlying estimator does not have a ``transform`` method, then\n        an ``AttributeError`` is raised.\n\n        Parameters\n        ----------\n        X : array-like\n\n        Returns\n        -------\n        transformed : array-like\n        \"\"\"\n        self._check_method(\"transform\")\n        X = self._check_array(X)\n        output_meta = self.transform_meta\n\n        if isinstance(X, da.Array):\n            if output_meta is None:\n                output_meta = _get_output_dask_ar_meta_for_estimator(\n                    _transform,\n                    self._postfit_estimator,\n                    X,\n                )\n            return X.map_blocks(\n                _transform,\n                estimator=self._postfit_estimator,\n                meta=output_meta,\n            )\n        elif isinstance(X, dd.DataFrame):\n            if output_meta is None:\n                output_meta = _transform(X._meta_nonempty, self._postfit_estimator)\n            try:\n                return X.map_partitions(\n                    _transform,\n                    self._postfit_estimator,\n                    output_meta,\n                    meta=output_meta,\n                )\n            except ValueError:\n                if output_meta is None:\n                    # dask-dataframe relies on dd.core.no_default\n                    # for infering meta\n                    output_meta = dd.core.no_default\n                return X.map_partitions(\n                    _transform,\n                    estimator=self._postfit_estimator,\n                    meta=output_meta,\n                )\n        else:\n            return _transform(X, estimator=self._postfit_estimator)\n\n    def score(self, X, y, compute=True):\n        \"\"\"Returns the score on the given data.\n\n   
     Parameters\n        ----------\n        X : array-like, shape = [n_samples, n_features]\n            Input data, where n_samples is the number of samples and\n            n_features is the number of features.\n\n        y : array-like, shape = [n_samples] or [n_samples, n_output], optional\n            Target relative to X for classification or regression;\n            None for unsupervised learning.\n\n        Returns\n        -------\n        score : float\n                return self.estimator.score(X, y)\n        \"\"\"\n        scoring = self.scoring\n        X = self._check_array(X)\n        y = self._check_array(y)\n\n        if not scoring:\n            if type(self._postfit_estimator).score == sklearn.base.RegressorMixin.score:\n                scoring = \"r2\"\n            elif (\n                type(self._postfit_estimator).score\n                == sklearn.base.ClassifierMixin.score\n            ):\n                scoring = \"accuracy\"\n        else:\n            scoring = self.scoring\n\n        if scoring:\n            if not dask.is_dask_collection(X) and not dask.is_dask_collection(y):\n                scorer = sklearn.metrics.get_scorer(scoring)\n            else:\n                scorer = get_scorer(scoring, compute=compute)\n            return scorer(self, X, y)\n        else:\n            return self._postfit_estimator.score(X, y)\n\n    def predict(self, X):\n        \"\"\"Predict for X.\n\n        For dask inputs, a dask array or dataframe is returned. 
For other\n        inputs (NumPy array, pandas dataframe, scipy sparse matrix), the\n        regular return value is returned.\n\n        Parameters\n        ----------\n        X : array-like\n\n        Returns\n        -------\n        y : array-like\n        \"\"\"\n        self._check_method(\"predict\")\n        X = self._check_array(X)\n        output_meta = self.predict_meta\n\n        if isinstance(X, da.Array):\n            if output_meta is None:\n                output_meta = _get_output_dask_ar_meta_for_estimator(\n                    _predict, self._postfit_estimator, X\n                )\n\n            result = X.map_blocks(\n                _predict,\n                estimator=self._postfit_estimator,\n                drop_axis=1,\n                meta=output_meta,\n            )\n            return result\n\n        elif isinstance(X, dd.DataFrame):\n            if output_meta is None:\n                # dask-dataframe relies on dd.core.no_default\n                # for infering meta\n                output_meta = _predict(X._meta_nonempty, self._postfit_estimator)\n            try:\n                return X.map_partitions(\n                    _predict,\n                    self._postfit_estimator,\n                    output_meta,\n                    meta=output_meta,\n                )\n            except ValueError:\n                if output_meta is None:\n                    output_meta = dd.core.no_default\n                return X.map_partitions(\n                    _predict,\n                    estimator=self._postfit_estimator,\n                    meta=output_meta,\n                )\n        else:\n            return _predict(X, estimator=self._postfit_estimator)\n\n    def predict_proba(self, X):\n        \"\"\"Probability estimates.\n\n        For dask inputs, a dask array or dataframe is returned. 
For other\n        inputs (NumPy array, pandas dataframe, scipy sparse matrix), the\n        regular return value is returned.\n\n        If the underlying estimator does not have a ``predict_proba``\n        method, then an ``AttributeError`` is raised.\n\n        Parameters\n        ----------\n        X : array or dataframe\n\n        Returns\n        -------\n        y : array-like\n        \"\"\"\n        X = self._check_array(X)\n\n        self._check_method(\"predict_proba\")\n\n        output_meta = self.predict_proba_meta\n\n        if isinstance(X, da.Array):\n            if output_meta is None:\n                output_meta = _get_output_dask_ar_meta_for_estimator(\n                    _predict_proba, self._postfit_estimator, X\n                )\n            # XXX: multiclass\n            return X.map_blocks(\n                _predict_proba,\n                estimator=self._postfit_estimator,\n                meta=output_meta,\n                chunks=(X.chunks[0], len(self._postfit_estimator.classes_)),\n            )\n        elif isinstance(X, dd.DataFrame):\n            if output_meta is None:\n                # dask-dataframe relies on dd.core.no_default\n                # for infering meta\n                output_meta = _predict_proba(X._meta_nonempty, self._postfit_estimator)\n            try:\n                return X.map_partitions(\n                    _predict_proba,\n                    self._postfit_estimator,\n                    output_meta,\n                    meta=output_meta,\n                )\n            except ValueError:\n                if output_meta is None:\n                    output_meta = dd.core.no_default\n                return X.map_partitions(\n                    _predict_proba, estimator=self._postfit_estimator, meta=output_meta\n                )\n        else:\n            return _predict_proba(X, estimator=self._postfit_estimator)\n\n    def predict_log_proba(self, X):\n        \"\"\"Log of probability 
estimates.\n\n        For dask inputs, a dask array or dataframe is returned. For other\n        inputs (NumPy array, pandas dataframe, scipy sparse matrix), the\n        regular return value is returned.\n\n        If the underlying estimator does not have a ``predict_proba``\n        method, then an ``AttributeError`` is raised.\n\n        Parameters\n        ----------\n        X : array or dataframe\n\n        Returns\n        -------\n        y : array-like\n        \"\"\"\n        self._check_method(\"predict_log_proba\")\n        return da.log(self.predict_proba(X))\n\n    def _check_method(self, method):\n        \"\"\"Check if self.estimator has 'method'.\n\n        Raises\n        ------\n        AttributeError\n        \"\"\"\n        estimator = self._postfit_estimator\n        if not hasattr(estimator, method):\n            msg = \"The wrapped estimator '{}' does not have a '{}' method.\".format(\n                estimator, method\n            )\n            raise AttributeError(msg)\n        return getattr(estimator, method)\n\n\nclass Incremental(ParallelPostFit):\n    \"\"\"Metaestimator for feeding Dask Arrays to an estimator blockwise.\n    This wrapper provides a bridge between Dask objects and estimators\n    implementing the ``partial_fit`` API. These *incremental learners* can\n    train on batches of data. This fits well with Dask's blocked data\n    structures.\n    .. note::\n       This meta-estimator is not appropriate for hyperparameter optimization\n       on larger-than-memory datasets.\n    See the `list of incremental learners`_ in the scikit-learn documentation\n    for a list of estimators that implement the ``partial_fit`` API. 
Note that\n    `Incremental` is not limited to just these classes, it will work on any\n    estimator implementing ``partial_fit``, including those defined outside of\n    scikit-learn itself.\n    Calling :meth:`Incremental.fit` with a Dask Array will pass each block of\n    the Dask array or arrays to ``estimator.partial_fit`` *sequentially*.\n    Like :class:`ParallelPostFit`, the methods available after fitting (e.g.\n    :meth:`Incremental.predict`, etc.) are all parallel and delayed.\n    The ``estimator_`` attribute is a clone of `estimator` that was actually\n    used during the call to ``fit``. All attributes learned during training\n    are available on ``Incremental`` directly.\n    .. _list of incremental learners: https://scikit-learn.org/stable/modules/computing.html#incremental-learning  # noqa\n    Parameters\n    ----------\n    estimator : Estimator\n        Any object supporting the scikit-learn ``partial_fit`` API.\n    scoring : string or callable, optional\n        A single string (see :ref:`scoring_parameter`) or a callable\n        (see :ref:`scoring`) to evaluate the predictions on the test set.\n        For evaluating multiple metrics, either give a list of (unique)\n        strings or a dict with names as keys and callables as values.\n        NOTE that when using custom scorers, each scorer should return a\n        single value. Metric functions returning a list/array of values\n        can be wrapped into multiple scorers that return one value each.\n        See :ref:`multimetric_grid_search` for an example.\n        .. 
warning::\n           If None, the estimator's default scorer (if available) is used.\n           Most scikit-learn estimators will convert large Dask arrays to\n           a single NumPy array, which may exhaust the memory of your worker.\n           You probably want to always specify `scoring`.\n    random_state : int or numpy.random.RandomState, optional\n        Random object that determines how to shuffle blocks.\n    shuffle_blocks : bool, default True\n        Determines whether to call ``partial_fit`` on a randomly selected chunk\n        of the Dask arrays (default), or to fit in sequential order. This does\n        not control shuffle between blocks or shuffling each block.\n    predict_meta: pd.Series, pd.DataFrame, np.array default: None(infer)\n        An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output\n        type of the estimator's ``predict`` call.\n        This meta is necessary for some estimators to work with\n        ``dask.dataframe`` and ``dask.array``\n    predict_proba_meta: pd.Series, pd.DataFrame, np.array default: None(infer)\n        An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output\n        type of the estimator's ``predict_proba`` call.\n        This meta is necessary for some estimators to work with\n        ``dask.dataframe`` and ``dask.array``\n    transform_meta: pd.Series, pd.DataFrame, np.array default: None(infer)\n        An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output\n        type of the estimator's ``transform`` call.\n        This meta is necessary for some estimators to work with\n        ``dask.dataframe`` and ``dask.array``\n    Attributes\n    ----------\n    estimator_ : Estimator\n        A clone of `estimator` that was actually fit during the ``.fit`` call.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        estimator=None,\n        scoring=None,\n        shuffle_blocks=True,\n        random_state=None,\n      
  assume_equal_chunks=True,\n        predict_meta=None,\n        predict_proba_meta=None,\n        transform_meta=None,\n    ):\n        self.shuffle_blocks = shuffle_blocks\n        self.random_state = random_state\n        self.assume_equal_chunks = assume_equal_chunks\n        super().__init__(\n            estimator=estimator,\n            scoring=scoring,\n            predict_meta=predict_meta,\n            predict_proba_meta=predict_proba_meta,\n            transform_meta=transform_meta,\n        )\n\n    @property\n    def _postfit_estimator(self):\n        check_is_fitted(self, \"estimator_\")\n        return self.estimator_\n\n    def _fit_for_estimator(self, estimator, X, y, **fit_kwargs):\n        check_scoring(estimator, self.scoring)\n        if not dask.is_dask_collection(X) and not dask.is_dask_collection(y):\n            try:\n                result = estimator.partial_fit(X=X, y=y, **fit_kwargs)\n            except ValueError:\n                result = estimator.partial_fit(\n                    X=X, y=y, classes=np.unique(y), **fit_kwargs\n                )\n        else:\n            result = fit(\n                estimator,\n                X,\n                y,\n                random_state=self.random_state,\n                shuffle_blocks=self.shuffle_blocks,\n                assume_equal_chunks=self.assume_equal_chunks,\n                **fit_kwargs,\n            )\n\n        copy_learned_attributes(result, self)\n        self.estimator_ = result\n        return self\n\n    def fit(self, X, y=None, **fit_kwargs):\n        estimator = sklearn.base.clone(self.estimator)\n        self._fit_for_estimator(estimator, X, y, **fit_kwargs)\n        return self\n\n    def partial_fit(self, X, y=None, **fit_kwargs):\n        \"\"\"Fit the underlying estimator.\n        If this estimator has not been previously fit, this is identical to\n        :meth:`Incremental.fit`. 
If it has been previously fit,\n        ``self.estimator_`` is used as the starting point.\n        Parameters\n        ----------\n        X, y : array-like\n        **kwargs\n            Additional fit-kwargs for the underlying estimator.\n        Returns\n        -------\n        self : object\n        \"\"\"\n        estimator = getattr(self, \"estimator_\", None)\n        if estimator is None:\n            estimator = sklearn.base.clone(self.estimator)\n        return self._fit_for_estimator(estimator, X, y, **fit_kwargs)\n\n\ndef handle_empty_partitions(output_meta):\n    if hasattr(output_meta, \"__array_function__\"):\n        if len(output_meta.shape) == 1:\n            shape = 0\n        else:\n            shape = list(output_meta.shape)\n            shape[0] = 0\n        ar = np.zeros(\n            shape=shape,\n            dtype=output_meta.dtype,\n            like=output_meta,\n        )\n        return ar\n    elif \"scipy.sparse\" in type(output_meta).__module__:\n        # sparse matrices don't support\n        # `like` due to non implemented __array_function__\n        # Refer https://github.com/scipy/scipy/issues/10362\n        # Note below works for both cupy and scipy sparse matrices\n        if len(output_meta.shape) == 1:\n            shape = 0\n        else:\n            shape = list(output_meta.shape)\n            shape[0] = 0\n        ar = type(output_meta)(shape, dtype=output_meta.dtype)\n        return ar\n    elif hasattr(output_meta, \"iloc\"):\n        return output_meta.iloc[:0, :]\n\n\ndef _predict(part, estimator, output_meta=None):\n    if part.shape[0] == 0 and output_meta is not None:\n        empty_output = handle_empty_partitions(output_meta)\n        if empty_output is not None:\n            return empty_output\n    return estimator.predict(part)\n\n\ndef _predict_proba(part, estimator, output_meta=None):\n    if part.shape[0] == 0 and output_meta is not None:\n        empty_output = handle_empty_partitions(output_meta)\n      
  if empty_output is not None:\n            return empty_output\n    return estimator.predict_proba(part)\n\n\ndef _transform(part, estimator, output_meta=None):\n    if part.shape[0] == 0 and output_meta is not None:\n        empty_output = handle_empty_partitions(output_meta)\n        if empty_output is not None:\n            return empty_output\n    return estimator.transform(part)\n\n\ndef _get_output_dask_ar_meta_for_estimator(model_fn, estimator, input_dask_ar):\n    \"\"\"\n    Returns the output metadata array\n    for the model function (predict, transform etc)\n    by running the appropriate function on dummy data\n    of shape (1, n_features)\n\n    Parameters\n    ----------\n\n    model_fn: Model function\n        _predict, _transform etc\n\n    estimator : Estimator\n        The underlying estimator that is fit.\n\n    input_dask_ar: The input dask_array\n\n    Returns\n    -------\n    metadata: metadata of the output dask array\n\n    \"\"\"\n    # sklearn fails if input array has size zero\n    # It requires at least 1 sample to run successfully\n    input_meta = input_dask_ar._meta\n    if hasattr(input_meta, \"__array_function__\"):\n        ar = np.zeros(\n            shape=(1, input_dask_ar.shape[1]),\n            dtype=input_dask_ar.dtype,\n            like=input_meta,\n        )\n    elif \"scipy.sparse\" in type(input_meta).__module__:\n        # sparse matrices don't support\n        # `like` due to unimplemented __array_function__\n        # Refer to https://github.com/scipy/scipy/issues/10362\n        # Note below works for both cupy and scipy sparse matrices\n        ar = type(input_meta)((1, input_dask_ar.shape[1]), dtype=input_dask_ar.dtype)\n    else:\n        func_name = model_fn.__name__.strip(\"_\")\n        msg = (\n            f\"Metadata for {func_name} is not provided, so Dask is \"\n            f\"running the {func_name} \"\n            \"function on a small dataset to guess output metadata. 
\"\n            \"As a result, It is possible that Dask will guess incorrectly.\"\n        )\n        warnings.warn(msg)\n        ar = np.zeros(shape=(1, input_dask_ar.shape[1]), dtype=input_dask_ar.dtype)\n    return model_fn(ar, estimator)\n\n\ndef copy_learned_attributes(from_estimator, to_estimator):\n    attrs = {k: v for k, v in vars(from_estimator).items() if k.endswith(\"_\")}\n\n    for k, v in attrs.items():\n        setattr(to_estimator, k, v)\n\n\ndef get_scorer(scoring: Union[str, Callable], compute: bool = True) -> Callable:\n    \"\"\"Get a scorer from string\n    Parameters\n    ----------\n    scoring : str | callable\n        scoring method as string. If callable it is returned as is.\n    Returns\n    -------\n    scorer : callable\n        The scorer.\n    \"\"\"\n    # This is the same as sklearns, only we use our SCORERS dict,\n    # and don't have back-compat code\n    if isinstance(scoring, str):\n        try:\n            scorer, kwargs = SCORERS[scoring]\n        except KeyError:\n            raise ValueError(\n                \"{} is not a valid scoring value. \"\n                \"Valid options are {}\".format(scoring, sorted(SCORERS))\n            )\n    else:\n        scorer = scoring\n        kwargs = {}\n\n    kwargs[\"compute\"] = compute\n\n    return make_scorer(scorer, **kwargs)\n\n\ndef check_scoring(estimator, scoring=None, **kwargs):\n    res = sklearn_check_scoring(estimator, scoring=scoring, **kwargs)\n    if scoring in SCORERS.keys():\n        func, kwargs = SCORERS[scoring]\n        return make_scorer(func, **kwargs)\n    return res\n\n\ndef fit(\n    model,\n    x,\n    y,\n    compute=True,\n    shuffle_blocks=True,\n    random_state=None,\n    assume_equal_chunks=False,\n    **kwargs,\n):\n    \"\"\"Fit scikit learn model against dask arrays\n    Model must support the ``partial_fit`` interface for online or batch\n    learning.\n    Ideally your rows are independent and identically distributed. 
By default,\n    this function will step through chunks of the arrays in random order.\n    Parameters\n    ----------\n    model: sklearn model\n        Any model supporting partial_fit interface\n    x: dask Array\n        Two dimensional array, likely tall and skinny\n    y: dask Array\n        One dimensional array with same chunks as x's rows\n    compute : bool\n        Whether to compute this result\n    shuffle_blocks : bool\n        Whether to shuffle the blocks with ``random_state`` or not\n    random_state : int or numpy.random.RandomState\n        Random state to use when shuffling blocks\n    kwargs:\n        options to pass to partial_fit\n    \"\"\"\n\n    nblocks, x_name = _blocks_and_name(x)\n    if y is not None:\n        y_nblocks, y_name = _blocks_and_name(y)\n        assert y_nblocks == nblocks\n    else:\n        y_name = \"\"\n\n    if not hasattr(model, \"partial_fit\"):\n        msg = \"The class '{}' does not implement 'partial_fit'.\"\n        raise ValueError(msg.format(type(model)))\n\n    order = list(range(nblocks))\n    if shuffle_blocks:\n        rng = sklearn.utils.check_random_state(random_state)\n        rng.shuffle(order)\n\n    name = \"fit-\" + dask.base.tokenize(model, x, y, kwargs, order)\n\n    if hasattr(x, \"chunks\") and x.ndim > 1:\n        x_extra = (0,)\n    else:\n        x_extra = ()\n\n    dsk = {(name, -1): model}\n    dsk.update(\n        {\n            (name, i): (\n                _partial_fit,\n                (name, i - 1),\n                (x_name, order[i]) + x_extra,\n                (y_name, order[i]),\n                kwargs,\n            )\n            for i in range(nblocks)\n        }\n    )\n\n    dependencies = [x]\n    if y is not None:\n        dependencies.append(y)\n    new_dsk = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies)\n    value = Delayed((name, nblocks - 1), new_dsk, layer=name)\n\n    if compute:\n        return value.compute()\n    else:\n        return 
value\n\n\ndef _blocks_and_name(obj):\n    if hasattr(obj, \"chunks\"):\n        nblocks = len(obj.chunks[0])\n        name = obj.name\n\n    elif hasattr(obj, \"npartitions\"):\n        # dataframe, bag\n        nblocks = obj.npartitions\n        if hasattr(obj, \"_name\"):\n            # dataframe\n            name = obj._name\n        else:\n            # bag\n            name = obj.name\n\n    else:\n        # neither an array nor a dataframe/bag: fail with a clear message\n        # instead of an UnboundLocalError at the return below\n        raise TypeError(\n            \"Expected a dask collection, got {}\".format(type(obj))\n        )\n\n    return nblocks, name\n\n\ndef _partial_fit(model, x, y, kwargs=None):\n    kwargs = kwargs or dict()\n    model.partial_fit(x, y, **kwargs)\n    return model\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/__init__.py",
    "content": "from .aggregate import DaskAggregatePlugin\nfrom .cross_join import DaskCrossJoinPlugin\nfrom .empty import DaskEmptyRelationPlugin\nfrom .explain import ExplainPlugin\nfrom .filter import DaskFilterPlugin\nfrom .join import DaskJoinPlugin\nfrom .limit import DaskLimitPlugin\nfrom .project import DaskProjectPlugin\nfrom .sample import SamplePlugin\nfrom .sort import DaskSortPlugin\nfrom .subquery_alias import SubqueryAlias\nfrom .table_scan import DaskTableScanPlugin\nfrom .union import DaskUnionPlugin\nfrom .values import DaskValuesPlugin\nfrom .window import DaskWindowPlugin\n\n# __all__ must contain names as strings, not the objects themselves\n__all__ = [\n    \"DaskAggregatePlugin\",\n    \"DaskEmptyRelationPlugin\",\n    \"DaskFilterPlugin\",\n    \"DaskJoinPlugin\",\n    \"DaskCrossJoinPlugin\",\n    \"DaskLimitPlugin\",\n    \"DaskProjectPlugin\",\n    \"DaskSortPlugin\",\n    \"DaskTableScanPlugin\",\n    \"DaskUnionPlugin\",\n    \"DaskValuesPlugin\",\n    \"DaskWindowPlugin\",\n    \"SamplePlugin\",\n    \"ExplainPlugin\",\n    \"SubqueryAlias\",\n]\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/aggregate.py",
    "content": "import logging\nimport operator\nfrom collections import defaultdict\nfrom functools import reduce\nfrom typing import TYPE_CHECKING, Any, Callable\n\nimport dask.dataframe as dd\nimport pandas as pd\nfrom dask import config as dask_config\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.physical.rex.convert import RexConverter\nfrom dask_sql.physical.rex.core.call import IsNullOperation\nfrom dask_sql.utils import is_cudf_type, new_temporary_column\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass ReduceAggregation(dd.Aggregation):\n    \"\"\"\n    A special form of an aggregation, that applies a given operation\n    on all elements in a group with \"reduce\".\n    \"\"\"\n\n    def __init__(self, name: str, operation: Callable):\n        series_aggregate = lambda s: s.aggregate(lambda x: reduce(operation, x))\n\n        super().__init__(name, series_aggregate, series_aggregate)\n\n\nclass AggregationOnPandas(dd.Aggregation):\n    \"\"\"\n    A special form of an aggregation, which does not apply the given function\n    (given as attribute name) directly to the dask groupby, but\n    via the groupby().apply() method. This is needed to call\n    functions directly on the pandas dataframes, but should be done\n    very carefully (as it is a performance bottleneck).\n    \"\"\"\n\n    def __init__(self, function_name: str):\n        def _f(s):\n            return s.apply(lambda s0: getattr(s0.dropna(), function_name)())\n\n        super().__init__(function_name, _f, _f)\n\n\nclass AggregationSpecification:\n    \"\"\"\n    Most of the aggregations in SQL are already\n    implemented 1:1 in dask and can just be called via their name\n    (e.g. AVG is the mean). 
However, sometimes those\n    implemented functions only work well for some datatypes.\n    This small container class therefore\n    can hold a custom aggregation function, which is\n    used for dtypes the built-in one does not support.\n    \"\"\"\n\n    def __init__(self, built_in_aggregation, custom_aggregation=None):\n        self.built_in_aggregation = built_in_aggregation\n        self.custom_aggregation = custom_aggregation or built_in_aggregation\n\n    def get_supported_aggregation(self, series):\n        built_in_aggregation = self.built_in_aggregation\n\n        # built-in aggregations work well for numeric types\n        if pd.api.types.is_numeric_dtype(series.dtype):\n            return built_in_aggregation\n\n        # Todo: Add Categorical when support comes to dask-sql\n        if built_in_aggregation in [\"min\", \"max\"]:\n            if pd.api.types.is_datetime64_any_dtype(series.dtype):\n                return built_in_aggregation\n\n            if pd.api.types.is_string_dtype(series.dtype):\n                # If dask_cudf strings dtype, return built-in aggregation\n                if is_cudf_type(series):\n                    return built_in_aggregation\n\n                # with pandas StringDtype built-in aggregations work\n                if isinstance(series.dtype, pd.StringDtype):\n                    return built_in_aggregation\n\n        return self.custom_aggregation\n\n\nclass DaskAggregatePlugin(BaseRelPlugin):\n    \"\"\"\n    A DaskAggregate is used in GROUP BY clauses, but also\n    when aggregating a function over the full dataset.\n\n    In the first case we need to find out which columns we need to\n    group over, in the second case we \"cheat\" and add a 1-column\n    to the dataframe, which allows us to reuse every aggregation\n    function we already know of.\n    As NULLs are not groupable in dask, we handle them specially\n    by adding a temporary column which is True for all NULL values\n    and False otherwise (and also group by it).\n\n   
 The rest is just a lot of column-name-bookkeeping.\n    Fortunately calcite will already make sure that each\n    aggregation function will only ever be called with a single input\n    column (by splitting the inner calculation to a step before).\n\n    Open TODO: So far we are following the dask default\n    to only have a single partition after the group by (which is usually\n    a reasonable assumption). It would be nice to control\n    these things via HINTs.\n    \"\"\"\n\n    class_name = [\"Aggregate\", \"Distinct\"]\n\n    AGGREGATION_MAPPING = {\n        \"sum\": AggregationSpecification(\"sum\", AggregationOnPandas(\"sum\")),\n        \"$sum0\": AggregationSpecification(\"sum\", AggregationOnPandas(\"sum\")),\n        \"any_value\": AggregationSpecification(\n            dd.Aggregation(\n                \"any_value\",\n                lambda s: s.sample(n=1).values,\n                lambda s0: s0.sample(n=1).values,\n            )\n        ),\n        \"avg\": AggregationSpecification(\"mean\", AggregationOnPandas(\"mean\")),\n        \"stddev\": AggregationSpecification(\"std\", AggregationOnPandas(\"std\")),\n        \"stddevsamp\": AggregationSpecification(\"std\", AggregationOnPandas(\"std\")),\n        \"stddev_samp\": AggregationSpecification(\"std\", AggregationOnPandas(\"std\")),\n        \"stddevpop\": AggregationSpecification(\n            dd.Aggregation(\n                \"stddevpop\",\n                lambda s: (s.count(), s.sum(), s.agg(lambda x: (x**2).sum())),\n                lambda count, sum, sum_of_squares: (\n                    count.sum(),\n                    sum.sum(),\n                    sum_of_squares.sum(),\n                ),\n                lambda count, sum, sum_of_squares: (\n                    (sum_of_squares / count) - (sum / count) ** 2\n                )\n                ** (1 / 2),\n            )\n        ),\n        \"stddev_pop\": AggregationSpecification(\n            dd.Aggregation(\n                
\"stddev_pop\",\n                lambda s: (s.count(), s.sum(), s.agg(lambda x: (x**2).sum())),\n                lambda count, sum, sum_of_squares: (\n                    count.sum(),\n                    sum.sum(),\n                    sum_of_squares.sum(),\n                ),\n                lambda count, sum, sum_of_squares: (\n                    (sum_of_squares / count) - (sum / count) ** 2\n                )\n                ** (1 / 2),\n            )\n        ),\n        \"bit_and\": AggregationSpecification(\n            ReduceAggregation(\"bit_and\", operator.and_)\n        ),\n        \"bit_or\": AggregationSpecification(ReduceAggregation(\"bit_or\", operator.or_)),\n        \"bit_xor\": AggregationSpecification(ReduceAggregation(\"bit_xor\", operator.xor)),\n        \"count\": AggregationSpecification(\"count\"),\n        \"every\": AggregationSpecification(\n            dd.Aggregation(\"every\", lambda s: s.all(), lambda s0: s0.all())\n        ),\n        \"max\": AggregationSpecification(\"max\", AggregationOnPandas(\"max\")),\n        \"min\": AggregationSpecification(\"min\", AggregationOnPandas(\"min\")),\n        \"single_value\": AggregationSpecification(\"first\"),\n        # is null was checked earlier, now only need to compute the sum the non null values\n        \"regr_count\": AggregationSpecification(\"sum\", AggregationOnPandas(\"sum\")),\n        \"regr_syy\": AggregationSpecification(\n            dd.Aggregation(\n                \"regr_syy\",\n                lambda s: (s.count(), s.sum(), s.agg(lambda x: (x**2).sum())),\n                lambda count, sum, sum_of_squares: (\n                    count.sum(),\n                    sum.sum(),\n                    sum_of_squares.sum(),\n                ),\n                lambda count, sum, sum_of_squares: (\n                    sum_of_squares - (sum * (sum / count))\n                ),\n            )\n        ),\n        \"regr_sxx\": AggregationSpecification(\n            dd.Aggregation(\n 
               \"regr_sxx\",\n                lambda s: (s.count(), s.sum(), s.agg(lambda x: (x**2).sum())),\n                lambda count, sum, sum_of_squares: (\n                    count.sum(),\n                    sum.sum(),\n                    sum_of_squares.sum(),\n                ),\n                lambda count, sum, sum_of_squares: (\n                    sum_of_squares - (sum * (sum / count))\n                ),\n            )\n        ),\n        \"variancepop\": AggregationSpecification(\n            dd.Aggregation(\n                \"variancepop\",\n                lambda s: (s.count(), s.sum(), s.agg(lambda x: (x**2).sum())),\n                lambda count, sum, sum_of_squares: (\n                    count.sum(),\n                    sum.sum(),\n                    sum_of_squares.sum(),\n                ),\n                lambda count, sum, sum_of_squares: (\n                    (sum_of_squares / count) - (sum / count) ** 2\n                ),\n            )\n        ),\n        \"variance_pop\": AggregationSpecification(\n            dd.Aggregation(\n                \"variance_pop\",\n                lambda s: (s.count(), s.sum(), s.agg(lambda x: (x**2).sum())),\n                lambda count, sum, sum_of_squares: (\n                    count.sum(),\n                    sum.sum(),\n                    sum_of_squares.sum(),\n                ),\n                lambda count, sum, sum_of_squares: (\n                    (sum_of_squares / count) - (sum / count) ** 2\n                ),\n            )\n        ),\n    }\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        (dc,) = self.assert_inputs(rel, 1, context)\n\n        agg = rel.aggregate()\n\n        df = dc.df\n        cc = dc.column_container\n\n        # We make our life easier with having unique column names\n        cc = cc.make_unique()\n\n        group_exprs = agg.getGroupSets()\n        group_columns = (\n            
agg.getDistinctColumns()\n            if agg.isDistinctNode()\n            else [group_expr.column_name(rel) for group_expr in group_exprs]\n        )\n\n        dc = DataContainer(df, cc)\n\n        if not group_columns:\n            # There was actually no GROUP BY specified in the SQL\n            # Still, this plan can also be used if we need to aggregate something over the full\n            # data sample\n            # To reuse the code, we just create a new column at the end with a single value\n            logger.debug(\"Performing full-table aggregation\")\n\n        # Do all aggregates\n        df_agg, output_column_order, cc = self._do_aggregations(\n            rel,\n            dc,\n            group_columns,\n            context,\n        )\n\n        # SQL does not care about the index, but if group columns were specified we'll want to keep those\n        df_agg = df_agg.reset_index(drop=(not group_columns))\n\n        def try_get_backend_by_frontend_name(oc):\n            try:\n                return cc.get_backend_by_frontend_name(oc)\n            except KeyError:\n                return oc\n\n        backend_output_column_order = [\n            try_get_backend_by_frontend_name(oc) for oc in output_column_order\n        ]\n\n        cc = ColumnContainer(df_agg.columns).limit_to(backend_output_column_order)\n\n        cc = self.fix_column_to_row_type(cc, rel.getRowType())\n        dc = DataContainer(df_agg, cc)\n        dc = self.fix_dtype_to_row_type(dc, rel.getRowType())\n        return dc\n\n    def _do_aggregations(\n        self,\n        rel: \"LogicalPlan\",\n        dc: DataContainer,\n        group_columns: list[str],\n        context: \"dask_sql.Context\",\n    ) -> tuple[dd.DataFrame, list[str], ColumnContainer]:\n        \"\"\"\n        Main functionality: return the result dataframe,\n        the output column order and the updated column container\n        \"\"\"\n        df = dc.df\n        cc = dc.column_container\n\n        # We might need it later.\n        # If not, let's hope 
that adding a single column should not\n        # be a huge problem...\n        additional_column_name = new_temporary_column(df)\n        df = df.assign(**{additional_column_name: 1})\n\n        # Add an entry for every grouped column, as SQL wants them first\n        output_column_order = group_columns.copy()\n\n        # Collect all aggregations we need to do\n        (\n            collected_aggregations,\n            output_column_order,\n            df,\n            cc,\n        ) = self._collect_aggregations(\n            rel, df, cc, context, additional_column_name, output_column_order\n        )\n\n        groupby_agg_options = dask_config.get(\"sql.aggregate\")\n\n        if not collected_aggregations:\n            backend_names = [\n                cc.get_backend_by_frontend_name(group_name)\n                for group_name in group_columns\n            ]\n            return (\n                df[backend_names].drop_duplicates(**groupby_agg_options),\n                output_column_order,\n                cc,\n            )\n\n        # Now we can go ahead and use these grouped aggregations\n        # to perform the actual aggregation\n        # It is very important to start with the non-filtered entry.\n        # Otherwise we might lose some entries in the grouped columns\n        df_result = None\n        key = (None, None)\n        if key in collected_aggregations:\n            aggregations = collected_aggregations.pop(key)\n            df_result = self._perform_aggregation(\n                DataContainer(df, cc),\n                None,\n                None,\n                aggregations,\n                additional_column_name,\n                group_columns,\n                groupby_agg_options,\n            )\n\n        # Now we can also do the rest\n        for (\n            filter_column,\n            distinct_column,\n        ), aggregations in collected_aggregations.items():\n            agg_result = self._perform_aggregation(\n               
 DataContainer(df, cc),\n                filter_column,\n                distinct_column,\n                aggregations,\n                additional_column_name,\n                group_columns,\n                groupby_agg_options,\n            )\n\n            # ... and finally concat the new data with the already present columns\n            if df_result is None:\n                df_result = agg_result\n            else:\n                df_result = df_result.assign(\n                    **{col: agg_result[col] for col in agg_result.columns}\n                )\n\n        return df_result, output_column_order, cc\n\n    def _collect_aggregations(\n        self,\n        rel: \"LogicalPlan\",\n        df: dd.DataFrame,\n        cc: ColumnContainer,\n        context: \"dask_sql.Context\",\n        additional_column_name: str,\n        output_column_order: list[str],\n    ) -> tuple[\n        dict[tuple[str, str], list[tuple[str, str, Any]]], list[str], dd.DataFrame\n    ]:\n        \"\"\"\n        Collect all aggregations together, which have the same filter column\n        so that the aggregations only need to be done once.\n\n        Returns the aggregations as mapping filter_column -> List of Aggregations\n        where the aggregations are in the form (input_col, output_col, aggregation function (or string))\n        \"\"\"\n        dc = DataContainer(df, cc)\n        agg = rel.aggregate()\n\n        input_rel = rel.get_inputs()[0]\n\n        collected_aggregations = defaultdict(list)\n\n        # convert and assign any input/filter columns that don't currently exist\n        new_columns = {}\n        for expr in agg.getNamedAggCalls():\n            assert expr.getExprType() in {\n                \"Alias\",\n                \"AggregateFunction\",\n                \"AggregateUDF\",\n            }, \"Do not know how to handle this case!\"\n            for input_expr in agg.getArgs(expr):\n                input_col = input_expr.column_name(input_rel)\n              
  if input_col not in cc._frontend_backend_mapping:\n                    random_name = new_temporary_column(df)\n                    new_columns[random_name] = RexConverter.convert(\n                        input_rel, input_expr, dc, context=context\n                    )\n                    cc = cc.add(input_col, random_name)\n            filter_expr = expr.getFilterExpr()\n            if filter_expr is not None:\n                filter_col = filter_expr.column_name(input_rel)\n                if filter_col not in cc._frontend_backend_mapping:\n                    random_name = new_temporary_column(df)\n                    new_columns[random_name] = RexConverter.convert(\n                        input_rel, filter_expr, dc, context=context\n                    )\n                    cc = cc.add(filter_col, random_name)\n        if new_columns:\n            df = df.assign(**new_columns)\n\n        for expr in agg.getNamedAggCalls():\n            schema_name = context.schema_name\n            aggregation_name = agg.getAggregationFuncName(expr).lower()\n\n            # Gather information about input columns\n            inputs = agg.getArgs(expr)\n\n            if aggregation_name == \"regr_count\":\n                is_null = IsNullOperation()\n                two_columns_proxy = new_temporary_column(df)\n                if len(inputs) == 1:\n                    # calcite sometimes gives one input/col to regr_count and\n                    # passes the other col as the filter column\n                    col1 = cc.get_backend_by_frontend_name(\n                        inputs[0].column_name(input_rel)\n                    )\n                    df = df.assign(**{two_columns_proxy: (~is_null(df[col1]))})\n\n                else:\n                    col1 = cc.get_backend_by_frontend_name(\n                        inputs[0].column_name(input_rel)\n                    )\n                    col2 = cc.get_backend_by_frontend_name(\n                        
inputs[1].column_name(input_rel)\n                    )\n                    # both cols should be not null\n                    df = df.assign(\n                        **{\n                            two_columns_proxy: (\n                                ~is_null(df[col1]) & (~is_null(df[col2]))\n                            )\n                        }\n                    )\n                input_col = two_columns_proxy\n            elif aggregation_name == \"regr_syy\":\n                input_col = inputs[0].column_name(input_rel)\n            elif aggregation_name == \"regr_sxx\":\n                input_col = inputs[1].column_name(input_rel)\n            elif len(inputs) == 1:\n                input_col = inputs[0].column_name(input_rel)\n            elif len(inputs) == 0:\n                input_col = additional_column_name\n            else:\n                raise NotImplementedError(\"Can not cope with more than one input\")\n\n            filter_expr = expr.getFilterExpr()\n            if filter_expr is not None:\n                filter_backend_col = cc.get_backend_by_frontend_name(\n                    filter_expr.column_name(input_rel)\n                )\n            else:\n                filter_backend_col = None\n\n            try:\n                # This unifies CPU and GPU behavior by ensuring that performing a\n                # sum on a null column results in null and not 0\n                if aggregation_name == \"sum\" and isinstance(df._meta, pd.DataFrame):\n                    aggregation_function = AggregationSpecification(\n                        dd.Aggregation(\n                            name=\"custom_sum\",\n                            chunk=lambda s: s.sum(min_count=1),\n                            agg=lambda s0: s0.sum(min_count=1),\n                        )\n                    )\n                else:\n                    aggregation_function = self.AGGREGATION_MAPPING[aggregation_name]\n            except KeyError:\n                
try:\n                    aggregation_function = context.schema[schema_name].functions[\n                        aggregation_name\n                    ]\n                except KeyError:  # pragma: no cover\n                    raise NotImplementedError(\n                        f\"Aggregation function {aggregation_name} not implemented (yet).\"\n                    )\n            if isinstance(aggregation_function, AggregationSpecification):\n                backend_name = cc.get_backend_by_frontend_name(input_col)\n                aggregation_function = aggregation_function.get_supported_aggregation(\n                    df[backend_name]\n                )\n\n            # Finally, extract the output column name\n            output_col = expr.toString()\n\n            # Store the aggregation\n            collected_aggregations[\n                (filter_backend_col, backend_name if expr.isDistinctAgg() else None)\n            ].append((input_col, output_col, aggregation_function))\n            output_column_order.append(output_col)\n\n        return collected_aggregations, output_column_order, df, cc\n\n    def _perform_aggregation(\n        self,\n        dc: DataContainer,\n        filter_column: str,\n        distinct_column: str,\n        aggregations: list[tuple[str, str, Any]],\n        additional_column_name: str,\n        group_columns: list[str],\n        groupby_agg_options: dict[str, Any] = {},\n    ):\n        tmp_df = dc.df\n\n        # format aggregations for Dask\n        aggregations_dict = defaultdict(dict)\n        for aggregation in aggregations:\n            input_col, output_col, aggregation_f = aggregation\n            input_col = dc.column_container.get_backend_by_frontend_name(input_col)\n\n            # There can be cases where certain Expression values can be present here that\n            # need to remain here until the projection phase. 
If we get a KeyError here\n            # we assume one of those cases.\n            try:\n                output_col = dc.column_container.get_backend_by_frontend_name(\n                    output_col\n                )\n            except KeyError:\n                logger.debug(f\"Using original output_col value of '{output_col}'\")\n\n            aggregations_dict[input_col][output_col] = aggregation_f\n\n        group_columns = [\n            dc.column_container.get_backend_by_frontend_name(group_name)\n            for group_name in group_columns\n        ]\n\n        # filter dataframe if specified\n        if filter_column:\n            filter_expression = tmp_df[filter_column]\n            tmp_df = tmp_df[filter_expression]\n            logger.debug(f\"Filtered by {filter_column} before aggregation.\")\n        if distinct_column:\n            tmp_df = tmp_df.drop_duplicates(\n                subset=(group_columns + [distinct_column]), **groupby_agg_options\n            )\n            logger.debug(\n                f\"Dropped duplicates from {distinct_column} before aggregation.\"\n            )\n\n        # we might need a temporary column name if no groupby columns are specified\n        if additional_column_name is None:\n            additional_column_name = new_temporary_column(dc.df)\n\n        # perform groupby operation\n        grouped_df = tmp_df.groupby(\n            by=(group_columns or [additional_column_name]), dropna=False\n        )\n\n        # apply the aggregation(s)\n        logger.debug(f\"Performing aggregation {dict(aggregations_dict)}\")\n        agg_result = grouped_df.agg(aggregations_dict, **groupby_agg_options)\n\n        for col in agg_result.columns:\n            logger.debug(col)\n\n        # fix the column names to a single level\n        agg_result.columns = agg_result.columns.get_level_values(-1)\n\n        return agg_result\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/cross_join.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nimport dask_sql.utils as utils\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass DaskCrossJoinPlugin(BaseRelPlugin):\n    \"\"\"\n    While similar to `DaskJoinPlugin`, a `CrossJoin` has enough of a differing\n    structure to justify its own plugin. This in turn limits the number of\n    Dask tasks that are generated for a `CrossJoin` when compared to a\n    standard `Join`.\n    \"\"\"\n\n    class_name = \"CrossJoin\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        # We now have two inputs (from left and right), so we fetch them both\n        dc_lhs, dc_rhs = self.assert_inputs(rel, 2, context)\n\n        df_lhs = dc_lhs.df\n        df_rhs = dc_rhs.df\n\n        # Create a 'key' column in both DataFrames to join on.\n        # `assign` returns new DataFrames, leaving the input DataFrames untouched.\n        cross_join_key = utils.new_temporary_column(df_lhs)\n        df_lhs = df_lhs.assign(**{cross_join_key: 1})\n        df_rhs = df_rhs.assign(**{cross_join_key: 1})\n\n        result = df_lhs.merge(df_rhs, on=cross_join_key, suffixes=(\"\", \"0\")).drop(\n            columns=cross_join_key\n        )\n        cc = ColumnContainer(result.columns)\n\n        # Rename columns like the rel specifies\n        row_type = rel.getRowType()\n        field_specifications = [str(f) for f in row_type.getFieldNames()]\n\n        cc = cc.rename(\n            {\n                from_col: to_col\n                for from_col, to_col in zip(cc.columns, field_specifications)\n            }\n        )\n        cc = self.fix_column_to_row_type(cc, row_type)\n        return DataContainer(result, cc)\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/empty.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass DaskEmptyRelationPlugin(BaseRelPlugin):\n    \"\"\"\n    When a SQL query does not contain a target table, this plugin is invoked to\n    create an empty DataFrame that the remaining expressions can operate against.\n    \"\"\"\n\n    class_name = \"EmptyRelation\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        col_names = (\n            rel.empty_relation().emptyColumnNames()\n            if len(rel.empty_relation().emptyColumnNames()) > 0\n            else [\"_empty\"]\n        )\n        data = None if len(rel.empty_relation().emptyColumnNames()) > 0 else [0]\n        return DataContainer(\n            dd.from_pandas(pd.DataFrame(data, columns=col_names), npartitions=1),\n            ColumnContainer(col_names),\n        )\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/explain.py",
    "content": "from typing import TYPE_CHECKING\n\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\n\nclass ExplainPlugin(BaseRelPlugin):\n    \"\"\"\n    Explain is used to explain the query with the EXPLAIN keyword\n    \"\"\"\n\n    class_name = \"Explain\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\"):\n        explain_strings = rel.explain().getExplainString()\n        return \"\\n\".join(explain_strings)\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/filter.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING, List, Union\n\nimport dask.config as dask_config\nimport dask.dataframe as dd\nimport numpy as np\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.physical.rex import RexConverter\nfrom dask_sql.physical.utils.filter import attempt_predicate_pushdown\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\ndef filter_or_scalar(\n    df: dd.DataFrame,\n    filter_condition: Union[np.bool_, dd.Series],\n    add_filters: List = None,\n):\n    \"\"\"\n    Some (complex) SQL queries can lead to a strange condition which is always true or false.\n    We do not need to filter in this case.\n    See https://github.com/dask-contrib/dask-sql/issues/87.\n    \"\"\"\n    if np.isscalar(filter_condition):\n        if not filter_condition:  # pragma: no cover\n            # empty dataset\n            logger.warning(\"Join condition is always false - returning empty dataset\")\n            return df.head(0, compute=False)\n        else:\n            return df\n\n    # In SQL, a NULL in a boolean is False on filtering\n    filter_condition = filter_condition.fillna(False)\n    out = df[filter_condition]\n    # dask-expr should implicitly handle predicate pushdown\n    if dask_config.get(\"sql.predicate_pushdown\") and not dd._dask_expr_enabled():\n        return attempt_predicate_pushdown(out, add_filters=add_filters)\n    else:\n        return out\n\n\nclass DaskFilterPlugin(BaseRelPlugin):\n    \"\"\"\n    DaskFilter is used on WHERE clauses.\n    We just evaluate the filter (which is of type RexNode) and apply it\n    \"\"\"\n\n    class_name = \"Filter\"\n\n    def convert(\n        self,\n        rel: \"LogicalPlan\",\n        context: \"dask_sql.Context\",\n    ) -> DataContainer:\n        (dc,) = self.assert_inputs(rel, 1, context)\n        df = dc.df\n    
    cc = dc.column_container\n\n        filter = rel.filter()\n\n        # All of the logic is handled in the RexConverter;\n        # we just need to apply it here\n        condition = filter.getCondition()\n        df_condition = RexConverter.convert(rel, condition, dc, context=context)\n        df = filter_or_scalar(df, df_condition)\n\n        cc = self.fix_column_to_row_type(cc, rel.getRowType())\n        return DataContainer(df, cc)\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/join.py",
    "content": "import logging\nimport operator\nimport warnings\nfrom functools import reduce\nfrom typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\nfrom dask import config as dask_config\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.physical.rel.logical.filter import filter_or_scalar\nfrom dask_sql.physical.rex import RexConverter\nfrom dask_sql.utils import is_cudf_type\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import Expression, LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass DaskJoinPlugin(BaseRelPlugin):\n    \"\"\"\n    A DaskJoin is used when (surprise) joining two tables.\n    SQL allows for quite complicated joins with difficult conditions.\n    dask/pandas only knows about equijoins on a specific column.\n\n    We use a trick, which is also used in e.g. blazingSQL:\n    we split the join condition into two parts:\n    * everything which is an equijoin\n    * the rest\n    The first part is then used for the dask merging,\n    whereas the second part is just applied as a filter afterwards.\n    This will make joining more time-consuming than it needs to be,\n    but so far, it is the only solution...\n    \"\"\"\n\n    class_name = \"Join\"\n\n    JOIN_TYPE_MAPPING = {\n        \"INNER\": \"inner\",\n        \"LEFT\": \"left\",\n        \"RIGHT\": \"right\",\n        \"FULL\": \"outer\",\n        \"LEFTSEMI\": \"leftsemi\",\n        \"LEFTANTI\": \"leftanti\",\n    }\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        # Joining is a bit more complicated, so let's do it in steps:\n\n        join = rel.join()\n\n        # 1. We now have two inputs (from left and right), so we fetch them both\n        dc_lhs, dc_rhs = self.assert_inputs(rel, 2, context)\n        cc_lhs = dc_lhs.column_container\n        cc_rhs = dc_rhs.column_container\n\n        # 2. 
dask's merge will do some smart things with columns, which have the same name\n        # on lhs and rhs (which also includes reordering).\n        # However, that will confuse our column numbering in SQL.\n        # So we make our life easier by converting the column names into unique names\n        # We will convert back in the end\n        cc_lhs_renamed = cc_lhs.make_unique(\"lhs\")\n        cc_rhs_renamed = cc_rhs.make_unique(\"rhs\")\n\n        dc_lhs_renamed = DataContainer(dc_lhs.df, cc_lhs_renamed)\n        dc_rhs_renamed = DataContainer(dc_rhs.df, cc_rhs_renamed)\n\n        df_lhs_renamed = dc_lhs_renamed.assign()\n        df_rhs_renamed = dc_rhs_renamed.assign()\n\n        join_type = join.getJoinType()\n        join_type = self.JOIN_TYPE_MAPPING[str(join_type)]\n        # TODO: update with correct implementation of leftsemi for CPU\n        # https://github.com/dask-contrib/dask-sql/issues/1190\n        if join_type == \"leftsemi\" and not is_cudf_type(df_lhs_renamed):\n            join_type = \"inner\"\n\n        # 3. The join condition can have two forms, that we can understand\n        # (a) a = b\n        # (b) X AND Y AND a = b AND Z ... (can also be multiple a = b)\n        # The first case is very simple and we do not need any additional filter\n        # In the second case we do a merge on all the a = b,\n        # and then apply a filter using the other expressions.\n        # In all other cases, we need to do a full table cross join and filter afterwards.\n        # This is probably nonsense for large tables, but there is no other\n        # known solution so far.\n\n        join_condition = join.getCondition()\n        lhs_on, rhs_on, filter_condition = None, None, None\n        # A user can write certain queries that really should be `cross join` queries\n        # that will still enter this portion of the logic. If the join_condition is\n        # None that means there are no conditions to join on. 
This means a cross join.\n        # By not entering this body during that condition we ensure that later on in\n        # processing we perform a cross join.\n        if join_condition is not None:\n            lhs_on, rhs_on, filter_condition = self._split_join_condition(\n                join_condition\n            )\n\n            # lhs_on and rhs_on are the indices of the columns to merge on.\n            # The given column indices are for the full, merged table which consists\n            # of lhs and rhs put side-by-side (in this order)\n            # We therefore need to normalize the rhs indices relative to the rhs table.\n            rhs_on = [index - len(df_lhs_renamed.columns) for index in rhs_on]\n\n            # 4. dask can only merge on the same column names.\n            # We therefore create new columns on purpose, which have a distinct name.\n            assert len(lhs_on) == len(rhs_on)\n\n        if lhs_on:\n            # 5. Now we can finally merge on these columns\n            # The resulting dataframe will contain all (renamed) columns from the lhs and rhs\n            # plus the added columns\n            df = self._join_on_columns(\n                df_lhs_renamed,\n                df_rhs_renamed,\n                lhs_on,\n                rhs_on,\n                join_type,\n            )\n        else:\n            # 5. 
We are in the complex join case\n            # where we have no column to merge on\n            # This means we have no other chance than to merge\n            # everything with everything...\n\n            # TODO: we should implement a shortcut\n            # for filter conditions that are always false\n\n            df = dd.merge(\n                df_lhs_renamed.assign(common=1),\n                df_rhs_renamed.assign(common=1),\n                on=\"common\",\n            ).drop(columns=\"common\")\n\n            warnings.warn(\n                \"Need to do a cross-join, which is typically very resource heavy\",\n                ResourceWarning,\n            )\n\n        # 6. So the next step is to make sure\n        # we have the correct column order (and to remove the temporary join columns)\n        if join_type in (\"leftsemi\", \"leftanti\"):\n            correct_column_order = list(df_lhs_renamed.columns)\n        else:\n            correct_column_order = list(df_lhs_renamed.columns) + list(\n                df_rhs_renamed.columns\n            )\n        cc = ColumnContainer(df.columns).limit_to(correct_column_order)\n\n        # and to rename them like the rel specifies\n        row_type = rel.getRowType()\n        field_specifications = [str(f) for f in row_type.getFieldNames()]\n        if join_type in (\"leftsemi\", \"leftanti\"):\n            field_specifications = field_specifications[: len(cc.columns)]\n\n        cc = cc.rename(\n            {\n                from_col: to_col\n                for from_col, to_col in zip(cc.columns, field_specifications)\n            }\n        )\n        cc = self.fix_column_to_row_type(cc, row_type, join_type)\n        dc = DataContainer(df, cc)\n\n        # 7. 
Last but not least we apply any filters by and-chaining together the filters\n        if filter_condition:\n            # This line is a bit of code duplication with RexCallPlugin - but I guess it is worth keeping it separate\n            filter_condition = reduce(\n                operator.and_,\n                [\n                    RexConverter.convert(rel, rex, dc, context=context)\n                    for rex in filter_condition\n                ],\n            )\n            logger.debug(f\"Additionally applying filter {filter_condition}\")\n            df = filter_or_scalar(df, filter_condition)\n            dc = DataContainer(df, cc)\n\n        dc = self.fix_dtype_to_row_type(dc, rel.getRowType(), join_type)\n        # # Rename underlying DataFrame column names back to their original values before returning\n        # df = dc.assign()\n        # dc = DataContainer(df, ColumnContainer(cc.columns))\n        return dc\n\n    def _join_on_columns(\n        self,\n        df_lhs_renamed: dd.DataFrame,\n        df_rhs_renamed: dd.DataFrame,\n        lhs_on: list[int],\n        rhs_on: list[int],\n        join_type: str,\n    ) -> dd.DataFrame:\n\n        # SQL compatibility: when joining on columns that\n        # contain NULLs, pandas will actually happily\n        # keep those NULLs. 
That is however not compatible with\n        # SQL, so we get rid of them here\n        if join_type in [\"inner\", \"right\"]:\n            df_lhs_filter = reduce(\n                operator.and_,\n                [~df_lhs_renamed.iloc[:, index].isna() for index in lhs_on],\n            )\n            df_lhs_renamed = df_lhs_renamed[df_lhs_filter]\n        if join_type in [\"inner\", \"left\", \"leftanti\", \"leftsemi\"]:\n            df_rhs_filter = reduce(\n                operator.and_,\n                [~df_rhs_renamed.iloc[:, index].isna() for index in rhs_on],\n            )\n            df_rhs_renamed = df_rhs_renamed[df_rhs_filter]\n\n        lhs_columns_to_add = {\n            f\"common_{i}\": df_lhs_renamed[\"lhs_\" + str(index)]\n            for i, index in enumerate(lhs_on)\n        }\n        rhs_columns_to_add = {\n            f\"common_{i}\": df_rhs_renamed.iloc[:, index]\n            for i, index in enumerate(rhs_on)\n        }\n\n        df_lhs_with_tmp = df_lhs_renamed.assign(**lhs_columns_to_add)\n        df_rhs_with_tmp = df_rhs_renamed.assign(**rhs_columns_to_add)\n        added_columns = list(lhs_columns_to_add.keys())\n\n        broadcast = dask_config.get(\"sql.join.broadcast\")\n        if join_type == \"leftanti\" and not is_cudf_type(df_lhs_with_tmp):\n            df = df_lhs_with_tmp.merge(\n                df_rhs_with_tmp,\n                on=added_columns,\n                how=\"left\",\n                broadcast=broadcast,\n                indicator=True,\n            ).drop(columns=added_columns)\n            df = df[df[\"_merge\"] == \"left_only\"].drop(\n                columns=[\"_merge\"] + list(df_rhs_with_tmp.columns), errors=\"ignore\"\n            )\n        else:\n            df = df_lhs_with_tmp.merge(\n                df_rhs_with_tmp,\n                on=added_columns,\n                how=join_type,\n                broadcast=broadcast,\n            ).drop(columns=added_columns)\n\n        return df\n\n    def 
_split_join_condition(\n        self, join_condition: \"Expression\"\n    ) -> tuple[list[int], list[int], list[\"Expression\"]]:\n        if str(join_condition.getRexType()) in [\"RexType.Literal\", \"RexType.Reference\"]:\n            return [], [], [join_condition]\n        elif not str(join_condition.getRexType()) == \"RexType.Call\":\n            raise NotImplementedError(\"Can not understand join condition.\")\n\n        lhs_on = []\n        rhs_on = []\n        filter_condition = []\n        try:\n            lhs_on, rhs_on, filter_condition_part = self._extract_lhs_rhs(\n                join_condition\n            )\n            filter_condition.extend(filter_condition_part)\n        except AssertionError:\n            filter_condition.append(join_condition)\n\n        if lhs_on and rhs_on:\n            return lhs_on, rhs_on, filter_condition\n\n        return [], [], [join_condition]\n\n    def _extract_lhs_rhs(self, rex):\n        assert str(rex.getRexType()) == \"RexType.Call\"\n\n        operator_name = str(rex.getOperatorName())\n        assert operator_name in [\"=\", \"AND\"]\n\n        operands = rex.getOperands()\n        assert len(operands) == 2\n\n        if operator_name == \"=\":\n\n            operand_lhs = operands[0]\n            operand_rhs = operands[1]\n\n            if (\n                str(operand_lhs.getRexType()) == \"RexType.Reference\"\n                and str(operand_rhs.getRexType()) == \"RexType.Reference\"\n            ):\n                lhs_index = operand_lhs.getIndex()\n                rhs_index = operand_rhs.getIndex()\n\n                # The rhs table always comes after the lhs\n                # table. 
Therefore we have a very simple\n                # way of checking, which index comes from which\n                # input\n                if lhs_index > rhs_index:\n                    lhs_index, rhs_index = rhs_index, lhs_index\n\n                return [lhs_index], [rhs_index], []\n\n            raise AssertionError(\n                \"Invalid join condition\"\n            )  # pragma: no cover. Do not know how it could be triggered.\n        else:\n            lhs_indices = []\n            rhs_indices = []\n            filter_conditions = []\n            for operand in operands:\n                try:\n                    lhs_index, rhs_index, filter_condition = self._extract_lhs_rhs(\n                        operand\n                    )\n                    filter_conditions.extend(filter_condition)\n                    lhs_indices.extend(lhs_index)\n                    rhs_indices.extend(rhs_index)\n                except AssertionError:\n                    filter_conditions.append(operand)\n\n            return lhs_indices, rhs_indices, filter_conditions\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/limit.py",
    "content": "from typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\nfrom dask import config as dask_config\nfrom dask.blockwise import Blockwise\nfrom dask.highlevelgraph import MaterializedLayer\nfrom dask.layers import DataFrameIOLayer\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.physical.rex import RexConverter\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\n\nclass DaskLimitPlugin(BaseRelPlugin):\n    \"\"\"\n    Limit is used to only get a certain part of the dataframe\n    (LIMIT).\n    \"\"\"\n\n    class_name = \"Limit\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        (dc,) = self.assert_inputs(rel, 1, context)\n        df = dc.df\n        cc = dc.column_container\n\n        # Retrieve the RexType::Literal values from the `LogicalPlan` Limit\n        # Fetch -> LIMIT\n        # Skip -> OFFSET\n        limit = RexConverter.convert(rel, rel.limit().getFetch(), df, context=context)\n        offset = RexConverter.convert(rel, rel.limit().getSkip(), df, context=context)\n\n        # apply offset to limit if specified\n        if limit and offset:\n            limit += offset\n\n        # apply limit and/or offset to DataFrame\n        df = self._apply_limit(df, limit, offset)\n        cc = self.fix_column_to_row_type(cc, rel.getRowType())\n\n        # No column type has changed, so no need to cast again\n        return DataContainer(df, cc)\n\n    def _apply_limit(self, df: dd.DataFrame, limit: int, offset: int) -> dd.DataFrame:\n        \"\"\"\n        Limit the dataframe to the window [offset, limit].\n\n        Unfortunately, Dask does not currently support row selection through `iloc`, so this must be done using a custom partition function.\n        However, it is sometimes possible to compute this window using `head` when an `offset` is not specified.\n        
\"\"\"\n        # if no offset is specified we can use `head` to compute the window\n        if not offset:\n            # if `check-first-partition` enabled, check if we have a relatively simple Dask graph and if so,\n            # check if the first partition contains our desired window\n            if (\n                dask_config.get(\"sql.limit.check-first-partition\")\n                and not dd._dask_expr_enabled()\n                and all(\n                    [\n                        isinstance(\n                            layer, (DataFrameIOLayer, Blockwise, MaterializedLayer)\n                        )\n                        for layer in df.dask.layers.values()\n                    ]\n                )\n                and limit <= len(df.partitions[0])\n            ):\n                return df.head(limit, compute=False)\n\n            return df.head(limit, npartitions=-1, compute=False)\n\n        # compute the size of each partition\n        # TODO: compute `cumsum` here when dask#9067 is resolved\n        partition_borders = df.map_partitions(lambda x: len(x))\n\n        def limit_partition_func(df, partition_borders, partition_info=None):\n            \"\"\"Limit the partition to values contained within the specified window, returning an empty dataframe if there are none\"\"\"\n\n            # with dask-expr we may need to explicitly compute here\n            if hasattr(partition_borders, \"compute\"):\n                partition_borders = partition_borders.compute()\n\n            # TODO: remove the `cumsum` call here when dask#9067 is resolved\n            partition_borders = partition_borders.cumsum().to_dict()\n            partition_index = (\n                partition_info[\"number\"] if partition_info is not None else 0\n            )\n\n            partition_border_left = (\n                partition_borders[partition_index - 1] if partition_index > 0 else 0\n            )\n            partition_border_right = 
partition_borders[partition_index]\n\n            if (limit and limit < partition_border_left) or (\n                offset >= partition_border_right\n            ):\n                return df.iloc[0:0]\n\n            from_index = max(offset - partition_border_left, 0)\n            to_index = (\n                min(limit, partition_border_right) if limit else partition_border_right\n            ) - partition_border_left\n\n            return df.iloc[from_index:to_index]\n\n        return df.map_partitions(\n            limit_partition_func,\n            partition_borders=partition_borders,\n        )\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/project.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nfrom dask_sql._datafusion_lib import RexType\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.physical.rex import RexConverter\nfrom dask_sql.utils import new_temporary_column\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass DaskProjectPlugin(BaseRelPlugin):\n    \"\"\"\n    A DaskProject is used to\n    (a) apply expressions to the columns and\n    (b) only select a subset of the columns\n    \"\"\"\n\n    class_name = \"Projection\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        # Get the input of the previous step\n        (dc,) = self.assert_inputs(rel, 1, context)\n\n        df = dc.df\n        cc = dc.column_container\n\n        # Collect all (new) columns\n        proj = rel.projection()\n        named_projects = proj.getNamedProjects()\n\n        column_names = []\n        new_columns = {}\n        new_mappings = {}\n\n        # Collect all (new) columns this Projection will limit to\n        for key, expr in named_projects:\n            key = str(key)\n            column_names.append(key)\n\n            # shortcut: if we have a column already, there is no need to re-assign it again\n            # this is only the case if the expr is a RexInputRef\n            if expr.getRexType() == RexType.Reference:\n                index = expr.getIndex()\n                backend_column_name = cc.get_backend_by_frontend_index(index)\n                logger.debug(\n                    f\"Not re-adding the same column {key} (but just referencing it)\"\n                )\n                new_mappings[key] = backend_column_name\n            else:\n                random_name = new_temporary_column(df)\n                new_columns[random_name] = RexConverter.convert(\n                   
 rel, expr, dc, context=context\n                )\n                logger.debug(f\"Adding a new column {key} out of {expr}\")\n                new_mappings[key] = random_name\n\n        # Actually add the new columns\n        if new_columns:\n            df = df.assign(**new_columns)\n\n        # and the new mappings\n        for key, backend_column_name in new_mappings.items():\n            cc = cc.add(key, backend_column_name)\n\n        # Make sure the order is correct\n        cc = cc.limit_to(column_names)\n\n        cc = self.fix_column_to_row_type(cc, rel.getRowType())\n        dc = DataContainer(df, cc)\n        dc = self.fix_dtype_to_row_type(dc, rel.getRowType())\n\n        return dc\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/sample.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING\n\nimport numpy as np\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql.java import org\n\nlogger = logging.getLogger(__name__)\n\n\nclass SamplePlugin(BaseRelPlugin):\n    \"\"\"\n    Sample is used on TABLESAMPLE clauses.\n    It returns only a fraction of the table, given by the\n    number in the arguments.\n    There exist two algorithms, SYSTEM or BERNOULLI.\n\n    SYSTEM is a very fast algorithm, which works on partition\n    level: a partition is kept with a probability given by the\n    percentage. This algorithm will - especially for very small\n    numbers of partitions - give wrong results. Only choose\n    it when you really have too much data to apply BERNOULLI\n    (which might never be the case in real world applications).\n\n    BERNOULLI samples each row separately and will still\n    give only an approximate fraction, but much closer to\n    the expected.\n    \"\"\"\n\n    class_name = \"com.dask.sql.nodes.DaskSample\"\n\n    def convert(\n        self, rel: \"org.apache.calcite.rel.RelNode\", context: \"dask_sql.Context\"\n    ) -> DataContainer:\n        (dc,) = self.assert_inputs(rel, 1, context)\n        df = dc.df\n        cc = dc.column_container\n\n        parameters = rel.getSamplingParameters()\n        is_bernoulli = parameters.isBernoulli()\n        fraction = float(parameters.getSamplingPercentage())\n        seed = parameters.getRepeatableSeed() if parameters.isRepeatable() else None\n\n        if is_bernoulli:\n            df = df.sample(frac=fraction, replace=False, random_state=seed)\n        else:\n            random_state = np.random.RandomState(seed)\n            random_choice = random_state.choice(\n                [True, False],\n                size=df.npartitions,\n                replace=True,\n                p=[fraction, 1 - fraction],\n     
       )\n\n            if random_choice.any():\n                df = df.partitions[random_choice]\n            else:\n                df = df.head(0, compute=False)\n\n        return DataContainer(df, cc)\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/sort.py",
    "content": "from typing import TYPE_CHECKING\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.physical.utils.sort import apply_sort\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\n\nclass DaskSortPlugin(BaseRelPlugin):\n    \"\"\"\n    DaskSort is used to sort by columns (ORDER BY).\n    \"\"\"\n\n    class_name = \"Sort\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        (dc,) = self.assert_inputs(rel, 1, context)\n        df = dc.df\n        cc = dc.column_container\n        sort_plan = rel.sort()\n        sort_expressions = sort_plan.getCollation()\n        sort_columns = [\n            cc.get_backend_by_frontend_name(expr.column_name(rel))\n            for expr in sort_expressions\n        ]\n        sort_ascending = [expr.isSortAscending() for expr in sort_expressions]\n        sort_null_first = [expr.isSortNullsFirst() for expr in sort_expressions]\n        sort_num_rows = sort_plan.getNumRows()\n\n        df = apply_sort(\n            df, sort_columns, sort_ascending, sort_null_first, sort_num_rows\n        )\n\n        cc = self.fix_column_to_row_type(cc, rel.getRowType())\n        # No column type has changed, so no need to cast again\n        return DataContainer(df, cc)\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/subquery_alias.py",
    "content": "from typing import TYPE_CHECKING\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\n\nclass SubqueryAlias(BaseRelPlugin):\n    \"\"\"\n    SubqueryAlias is used to assign an alias to a table and/or subquery\n    \"\"\"\n\n    class_name = \"SubqueryAlias\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\"):\n        (dc,) = self.assert_inputs(rel, 1, context)\n\n        cc = dc.column_container\n\n        alias = rel.subquery_alias().getAlias()\n\n        return DataContainer(\n            dc.df,\n            cc.rename(\n                {\n                    col: renamed_col\n                    for col, renamed_col in zip(\n                        cc.columns,\n                        (f\"{alias}.{col.split('.')[-1]}\" for col in cc.columns),\n                    )\n                }\n            ),\n        )\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/table_scan.py",
    "content": "import logging\nimport operator\nfrom functools import reduce\nfrom typing import TYPE_CHECKING\n\nfrom dask.dataframe import _dask_expr_enabled\nfrom dask.utils_test import hlg_layer\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.physical.rel.logical.filter import filter_or_scalar\nfrom dask_sql.physical.rex import RexConverter\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass DaskTableScanPlugin(BaseRelPlugin):\n    \"\"\"\n    A DaskTableScan is the main ingredient: it will get the data\n    from the database. It is always used, when the SQL looks like\n\n        SELECT .... FROM table ....\n\n    We need to get the dask dataframe from the registered\n    tables and return the requested columns from it.\n    \"\"\"\n\n    class_name = \"TableScan\"\n\n    def convert(\n        self,\n        rel: \"LogicalPlan\",\n        context: \"dask_sql.Context\",\n    ) -> DataContainer:\n        # There should not be any input. 
This is the first step.\n        self.assert_inputs(rel, 0)\n\n        # Rust table_scan instance handle\n        table_scan = rel.table_scan()\n\n        # The table(s) we need to return\n        dask_table = rel.getTable()\n        schema_name, table_name = (n.lower() for n in context.fqn(dask_table))\n\n        dc = context.schema[schema_name].tables[table_name]\n\n        # Apply filter before projections since filter columns may not be in projections\n        dc = self._apply_filters(table_scan, rel, dc, context)\n        dc = self._apply_projections(table_scan, dask_table, dc)\n\n        cc = dc.column_container\n        cc = self.fix_column_to_row_type(cc, rel.getRowType())\n        dc = DataContainer(dc.df, cc)\n        dc = self.fix_dtype_to_row_type(dc, rel.getRowType())\n        return dc\n\n    def _apply_projections(self, table_scan, dask_table, dc):\n        # If the 'TableScan' instance contains projected columns only retrieve those columns\n        # otherwise get all projected columns from the 'Projection' instance, which is contained\n        # in the 'RelDataType' instance, aka 'row_type'\n        df = dc.df\n        cc = dc.column_container\n        if table_scan.containsProjections():\n            field_specifications = list(\n                map(cc.get_backend_by_frontend_name, table_scan.getTableScanProjects())\n            )  # Assumes these are column projections only and field names match table column names\n\n            df = df[field_specifications]\n        else:\n            field_specifications = [\n                str(f) for f in dask_table.getRowType().getFieldNames()\n            ]\n        cc = cc.limit_to(field_specifications)\n        return DataContainer(df, cc)\n\n    def _apply_filters(self, table_scan, rel, dc, context):\n        df = dc.df\n        cc = dc.column_container\n        all_filters = table_scan.getFilters()\n        conjunctive_dnf_filters = table_scan.getDNFFilters().filtered_exprs\n        non_dnf_filters = 
table_scan.getDNFFilters().io_unfilterable_exprs\n\n        if conjunctive_dnf_filters:\n            # Extract the PyExprs from the conjunctive DNF filters\n            filter_exprs = [f[0] for f in conjunctive_dnf_filters]\n            if non_dnf_filters:\n                filter_exprs.extend(non_dnf_filters)\n\n            df_condition = reduce(\n                operator.and_,\n                [\n                    RexConverter.convert(rel, rex, dc, context=context)\n                    for rex in filter_exprs\n                ],\n            )\n            df = filter_or_scalar(\n                df, df_condition, add_filters=[f[1] for f in conjunctive_dnf_filters]\n            )\n        elif all_filters:\n            df_condition = reduce(\n                operator.and_,\n                [\n                    RexConverter.convert(rel, rex, dc, context=context)\n                    for rex in all_filters\n                ],\n            )\n            df = filter_or_scalar(df, df_condition)\n\n        if not _dask_expr_enabled():\n            try:\n                logger.debug(hlg_layer(df.dask, \"read-parquet\").creation_info)\n            except KeyError:\n                pass\n\n        return DataContainer(df, cc)\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/union.py",
    "content": "from typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\n\ndef _extract_df(obj_cc, obj_df, output_field_names):\n    # For concatenating, they should have exactly the same fields\n    assert len(obj_cc.columns) == len(output_field_names)\n    obj_cc = obj_cc.rename(\n        columns={\n            col: output_col\n            for col, output_col in zip(obj_cc.columns, output_field_names)\n        }\n    )\n    obj_dc = DataContainer(obj_df, obj_cc)\n    return obj_dc.assign()\n\n\nclass DaskUnionPlugin(BaseRelPlugin):\n    \"\"\"\n    DaskUnion is used on UNION clauses.\n    It just concatonates the two data frames.\n    \"\"\"\n\n    class_name = \"Union\"\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        # Late import to remove cycling dependency\n        from dask_sql.physical.rel.convert import RelConverter\n\n        objs_dc = [\n            RelConverter.convert(input_rel, context) for input_rel in rel.get_inputs()\n        ]\n\n        objs_df = [obj.df for obj in objs_dc]\n        objs_cc = [obj.column_container for obj in objs_dc]\n\n        output_field_names = [str(x) for x in rel.getRowType().getFieldNames()]\n        obj_dfs = []\n        for i, obj_df in enumerate(objs_df):\n            obj_dfs.append(\n                _extract_df(\n                    obj_cc=objs_cc[i],\n                    obj_df=obj_df,\n                    output_field_names=output_field_names,\n                )\n            )\n\n        _ = [self.check_columns_from_row_type(df, rel.getRowType()) for df in obj_dfs]\n\n        df = dd.concat(obj_dfs)\n\n        cc = ColumnContainer(df.columns)\n        cc = self.fix_column_to_row_type(cc, rel.getRowType())\n        dc = DataContainer(df, 
cc)\n        dc = self.fix_dtype_to_row_type(dc, rel.getRowType())\n        return dc\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/values.py",
    "content": "from typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.physical.rex import RexConverter\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql.java import org\n\n\nclass DaskValuesPlugin(BaseRelPlugin):\n    \"\"\"\n    A DaskValue is a table just consisting of\n    raw values (nothing database-dependent).\n    For example\n\n        SELECT 1 + 1;\n\n    We generate a pandas dataframe and a dask\n    dataframe out of it directly here.\n    We assume that this will only ever be used for small\n    data samples.\n    \"\"\"\n\n    class_name = \"com.dask.sql.nodes.DaskValues\"\n\n    def convert(\n        self, rel: \"org.apache.calcite.rel.RelNode\", context: \"dask_sql.Context\"\n    ) -> DataContainer:\n        # There should not be any input. This is the first step.\n        self.assert_inputs(rel, 0)\n\n        rex_expression_rows = list(rel.getTuples())\n        rows = []\n        for rex_expression_row in rex_expression_rows:\n            # We convert each of the cells in the row\n            # using a RexConverter.\n            # As we do not have any information on the\n            # column headers, we just name them with\n            # their index.\n            rows.append(\n                {\n                    str(i): RexConverter.convert(rex_cell, None, context=context)\n                    for i, rex_cell in enumerate(rex_expression_row)\n                }\n            )\n\n        # TODO: we explicitely reference pandas and dask here -> might we worth making this more general\n        # We assume here that when using the values plan, the resulting dataframe will be quite small\n        if rows:\n            df = pd.DataFrame(rows)\n        else:\n            field_names = [str(x) for x in rel.getRowType().getFieldNames()]\n            df = 
pd.DataFrame(columns=field_names)\n\n        df = dd.from_pandas(df, npartitions=1)\n        cc = ColumnContainer(df.columns)\n\n        cc = self.fix_column_to_row_type(cc, rel.getRowType())\n        dc = DataContainer(df, cc)\n        dc = self.fix_dtype_to_row_type(dc, rel.getRowType())\n        return dc\n"
  },
  {
    "path": "dask_sql/physical/rel/logical/window.py",
    "content": "import logging\nfrom collections import namedtuple\nfrom functools import partial\nfrom typing import TYPE_CHECKING, Callable, Optional\n\nimport dask.dataframe as dd\nimport numpy as np\nimport pandas as pd\nfrom pandas.api.indexers import BaseIndexer\n\nfrom dask_sql.datacontainer import ColumnContainer, DataContainer\nfrom dask_sql.physical.rel.base import BaseRelPlugin\nfrom dask_sql.physical.rex.convert import RexConverter\nfrom dask_sql.physical.utils.sort import sort_partition_func\nfrom dask_sql.utils import LoggableDataFrame, new_temporary_column\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass OverOperation:\n    def __call__(self, partitioned_group, *args) -> pd.Series:\n        \"\"\"Call the stored function\"\"\"\n        return self.call(partitioned_group, *args)\n\n\nclass FirstValueOperation(OverOperation):\n    def call(self, partitioned_group, value_col):\n        return partitioned_group[value_col].apply(lambda x: x.iloc[0])\n\n\nclass LastValueOperation(OverOperation):\n    def call(self, partitioned_group, value_col):\n        return partitioned_group[value_col].apply(lambda x: x.iloc[-1])\n\n\nclass SumOperation(OverOperation):\n    def call(self, partitioned_group, value_col):\n        return partitioned_group[value_col].sum()\n\n\nclass CountOperation(OverOperation):\n    def call(self, partitioned_group, value_col=None):\n        if value_col is None:\n            return partitioned_group.count().iloc[:, 0].fillna(0)\n        else:\n            return partitioned_group[value_col].count().fillna(0)\n\n\nclass MaxOperation(OverOperation):\n    def call(self, partitioned_group, value_col):\n        return partitioned_group[value_col].max()\n\n\nclass MinOperation(OverOperation):\n    def call(self, partitioned_group, value_col):\n        return partitioned_group[value_col].min()\n\n\nclass AvgOperation(OverOperation):\n    def 
call(self, partitioned_group, value_col):\n        return partitioned_group[value_col].mean()\n\n\nclass BoundDescription(\n    namedtuple(\n        \"BoundDescription\",\n        [\"is_unbounded\", \"is_preceding\", \"is_following\", \"is_current_row\", \"offset\"],\n    )\n):\n    \"\"\"\n    Small helper class to wrap a PyWindowFrame\n    object. We can directly ship PyWindowFrame to workers in the future\n    \"\"\"\n\n    pass\n\n\ndef to_bound_description(\n    windowFrame,\n) -> BoundDescription:\n    \"\"\"Convert the PyWindowFrame object to a BoundDescription representation,\n    replacing any literals or references to constants\"\"\"\n    return BoundDescription(\n        is_unbounded=bool(windowFrame.isUnbounded()),\n        is_preceding=bool(windowFrame.isPreceding()),\n        is_following=bool(windowFrame.isFollowing()),\n        is_current_row=bool(windowFrame.isCurrentRow()),\n        offset=windowFrame.getOffset(),\n    )\n\n\nclass Indexer(BaseIndexer):\n    \"\"\"\n    Window description used for complex windows with arbitrary start and end.\n    This class is directly taken from the fugue project.\n    \"\"\"\n\n    def __init__(self, start: int, end: int):\n        super().__init__(self, start=start, end=end)\n\n    def _get_window_bounds(\n        self,\n        num_values: int = 0,\n        min_periods: Optional[int] = None,\n        center: Optional[bool] = None,\n        closed: Optional[str] = None,\n    ) -> tuple[np.ndarray, np.ndarray]:\n        if self.start is None:\n            start = np.zeros(num_values, dtype=np.int64)\n        else:\n            start = np.arange(self.start, self.start + num_values, dtype=np.int64)\n            if self.start < 0:\n                start[: -self.start] = 0\n            elif self.start > 0:\n                start[-self.start :] = num_values\n        if self.end is None:\n            end = np.full(num_values, num_values, dtype=np.int64)\n        else:\n            end = np.arange(self.end + 1, 
self.end + 1 + num_values, dtype=np.int64)\n            if self.end > 0:\n                end[-self.end :] = num_values\n            elif self.end < 0:\n                end[: -self.end] = 0\n            else:  # pragma: no cover\n                raise AssertionError(\n                    \"This case should have been handled before! Please report this bug\"\n                )\n        return start, end\n\n    def get_window_bounds(\n        self,\n        num_values: int = 0,\n        min_periods: Optional[int] = None,\n        center: Optional[bool] = None,\n        closed: Optional[str] = None,\n        step: Optional[int] = None,\n    ) -> tuple[np.ndarray, np.ndarray]:\n        return self._get_window_bounds(num_values, min_periods, center, closed)\n\n\ndef map_on_each_group(\n    partitioned_group: pd.DataFrame,\n    sort_columns: list[str],\n    sort_ascending: list[bool],\n    sort_null_first: list[bool],\n    lower_bound: BoundDescription,\n    upper_bound: BoundDescription,\n    operations: list[tuple[Callable, str, list[str]]],\n):\n    \"\"\"Internal function mapped on each group of the dataframe after partitioning\"\"\"\n    # Apply sorting\n    if sort_columns:\n        partitioned_group = sort_partition_func(\n            partitioned_group, sort_columns, sort_ascending, sort_null_first\n        )\n\n    # Apply the windowing operation\n    if lower_bound.is_unbounded and (\n        upper_bound.is_current_row or upper_bound.offset == 0\n    ):\n        windowed_group = partitioned_group.expanding(min_periods=1)\n    elif lower_bound.is_preceding and (\n        upper_bound.is_current_row or upper_bound.offset == 0\n    ):\n        windowed_group = partitioned_group.rolling(\n            window=lower_bound.offset + 1,\n            min_periods=1,\n        )\n    else:\n        lower_offset = lower_bound.offset if not lower_bound.is_current_row else 0\n        if lower_bound.is_preceding and lower_offset is not None:\n            lower_offset *= -1\n        
upper_offset = upper_bound.offset if not upper_bound.is_current_row else 0\n        if upper_bound.is_preceding and upper_offset is not None:\n            upper_offset *= -1\n\n        indexer = Indexer(lower_offset, upper_offset)\n        windowed_group = partitioned_group.rolling(window=indexer, min_periods=1)\n\n    # Calculate the results\n    new_columns = {}\n    for f, new_column_name, temporary_operand_columns in operations:\n        if f is None:\n            # This is the row_number operator.\n            # We do not need to do any windowing\n            column_result = range(1, len(partitioned_group) + 1)\n        else:\n            column_result = f(windowed_group, *temporary_operand_columns)\n\n        new_columns[new_column_name] = column_result\n\n    # Now apply all columns at once\n    partitioned_group = partitioned_group.assign(**new_columns)\n    return partitioned_group\n\n\nclass DaskWindowPlugin(BaseRelPlugin):\n    \"\"\"\n    A DaskWindow is an expression, which calculates a given function over the dataframe\n    while first optionally partitioning the data and optionally sorting it.\n\n    Expressions like `F OVER (PARTITION BY x ORDER BY y)` apply F on each\n    partition separately and sort by y before applying F. The result of this\n    calculation has however the same length as the input dataframe - it is not an aggregation.\n    Typical examples include ROW_NUMBER and lagging.\n    \"\"\"\n\n    class_name = \"Window\"\n\n    OPERATION_MAPPING = {\n        \"row_number\": None,  # That is the easiest one: we do not even need to have any windowing.
 We therefore treat it separately\n        \"$sum0\": SumOperation(),\n        \"sum\": SumOperation(),\n        \"count\": CountOperation(),\n        \"max\": MaxOperation(),\n        \"min\": MinOperation(),\n        \"single_value\": FirstValueOperation(),\n        \"first_value\": FirstValueOperation(),\n        \"last_value\": LastValueOperation(),\n        \"avg\": AvgOperation(),\n    }\n\n    def convert(self, rel: \"LogicalPlan\", context: \"dask_sql.Context\") -> DataContainer:\n        (dc,) = self.assert_inputs(rel, 1, context)\n\n        # Output to the right field names right away\n        field_names = rel.getRowType().getFieldNames()\n\n        for window in rel.window().getGroups():\n            dc = self._apply_window(rel, window, dc, field_names, context)\n\n        # Finally, fix the output schema if needed\n        df = dc.df\n        cc = dc.column_container\n        cc = self.fix_column_to_row_type(cc, rel.getRowType())\n        dc = DataContainer(df, cc)\n        dc = self.fix_dtype_to_row_type(dc, rel.getRowType())\n\n        return dc\n\n    def _apply_window(\n        self,\n        rel,\n        window,\n        dc: DataContainer,\n        field_names: list[str],\n        context: \"dask_sql.Context\",\n    ):\n        temporary_columns = []\n\n        df = dc.df\n        cc = dc.column_container\n\n        # Now extract the groupby and order information\n        sort_columns, sort_ascending, sort_null_first = self._extract_ordering(\n            rel, window, cc\n        )\n        logger.debug(\n            f\"Before applying the function, sorting according to {sort_columns}.\"\n        )\n\n        df, group_columns, temporary_columns = self._extract_groupby(\n            df, rel, window, dc, context\n        )\n        logger.debug(\n            f\"Before applying the function, partitioning according to {group_columns}.\"\n        )\n\n        operations, df = self._extract_operations(rel, window, df, dc, context)\n        for _, _, 
cols in operations:\n            temporary_columns += cols\n\n        newly_created_columns = [new_column for _, new_column, _ in operations]\n\n        logger.debug(f\"Will create {newly_created_columns} new columns\")\n\n        # Default window bounds when no frame is specified: unbounded preceding to\n        # current row if there is an ORDER BY, otherwise unbounded preceding to\n        # unbounded following\n        if not rel.window().getWindowFrame(window):\n            lower_bound = BoundDescription(\n                is_unbounded=True,\n                is_preceding=True,\n                is_following=False,\n                is_current_row=False,\n                offset=None,\n            )\n            upper_bound = (\n                BoundDescription(\n                    is_unbounded=False,\n                    is_preceding=False,\n                    is_following=False,\n                    is_current_row=True,\n                    offset=None,\n                )\n                if sort_columns\n                else BoundDescription(\n                    is_unbounded=True,\n                    is_preceding=False,\n                    is_following=True,\n                    is_current_row=False,\n                    offset=None,\n                )\n            )\n        else:\n            lower_bound = to_bound_description(\n                rel.window().getWindowFrame(window).getLowerBound(),\n            )\n            upper_bound = to_bound_description(\n                rel.window().getWindowFrame(window).getUpperBound(),\n            )\n\n        # Apply the windowing operation\n        filled_map = partial(\n            map_on_each_group,\n            sort_columns=sort_columns,\n            sort_ascending=sort_ascending,\n            sort_null_first=sort_null_first,\n            lower_bound=lower_bound,\n            upper_bound=upper_bound,\n            operations=operations,\n        )\n\n        # TODO: That is a bit of a hack. 
We should really use the real column dtype\n        meta = df._meta.assign(**{col: 0.0 for col in newly_created_columns})\n\n        df = df.groupby(group_columns, dropna=False)[df.columns.tolist()].apply(\n            filled_map, meta=meta\n        )\n        logger.debug(\n            f\"Having created a dataframe {LoggableDataFrame(df)} after windowing. Will now drop {temporary_columns}.\"\n        )\n        df = df.drop(columns=temporary_columns).reset_index(drop=True)\n\n        dc = DataContainer(df, cc)\n        df = dc.df\n        cc = dc.column_container\n        for c in newly_created_columns:\n            field_name = field_names[len(cc.columns)]\n            cc = cc.add(field_name, c)\n        dc = DataContainer(df, cc)\n        logger.debug(\n            f\"Removed unneeded columns and registered new ones: {LoggableDataFrame(dc)}.\"\n        )\n        return dc\n\n    def _extract_groupby(\n        self,\n        df: dd.DataFrame,\n        rel,\n        window,\n        dc: DataContainer,\n        context: \"dask_sql.Context\",\n    ) -> tuple[dd.DataFrame, str]:\n        \"\"\"Prepare grouping columns we can later use while applying the main function\"\"\"\n        partition_keys = rel.window().getPartitionExprs(window)\n        if partition_keys:\n            group_columns = [\n                dc.column_container.get_backend_by_frontend_name(o.column_name(rel))\n                for o in partition_keys\n            ]\n            temporary_columns = []\n        else:\n            temp_col = new_temporary_column(df)\n            df = df.assign(**{temp_col: 1})\n            group_columns = [temp_col]\n            temporary_columns = [temp_col]\n\n        return df, group_columns, temporary_columns\n\n    def _extract_ordering(\n        self, rel, window, cc: ColumnContainer\n    ) -> tuple[str, str, str]:\n        \"\"\"Prepare sorting information we can later use while applying the main function\"\"\"\n        logger.debug(\n            \"Error is 
about to be encountered, FIX me when bindings are available in subsequent PR\"\n        )\n        # TODO: This was commented out for flake8 CI passing and needs to be handled\n        sort_expressions = rel.window().getSortExprs(window)\n        sort_columns = [\n            cc.get_backend_by_frontend_name(expr.column_name(rel))\n            for expr in sort_expressions\n        ]\n        sort_ascending = [expr.isSortAscending() for expr in sort_expressions]\n        sort_null_first = [expr.isSortNullsFirst() for expr in sort_expressions]\n        return sort_columns, sort_ascending, sort_null_first\n\n    def _extract_operations(\n        self,\n        rel,\n        window,\n        df: dd.DataFrame,\n        dc: DataContainer,\n        context: \"dask_sql.Context\",\n    ) -> list[tuple[Callable, str, list[str]]]:\n        # Finally apply the actual function on each group separately\n        operations = []\n\n        # TODO: datafusion returns only window func expression per window\n        # This can be optimized in the physical plan to collect all aggs for a given window\n        operator_name = rel.window().getWindowFuncName(window).lower()\n\n        try:\n            operation = self.OPERATION_MAPPING[operator_name]\n        except KeyError:  # pragma: no cover\n            try:\n                operation = context.schema[context.schema_name].functions[operator_name]\n            except KeyError:  # pragma: no cover\n                raise NotImplementedError(f\"{operator_name} not (yet) implemented\")\n\n        logger.debug(f\"Executing {operator_name} on {str(LoggableDataFrame(df))}\")\n\n        # TODO: can be optimized by re-using already present columns\n        temporary_operand_columns = {\n            new_temporary_column(df): RexConverter.convert(rel, o, dc, context=context)\n            for o in rel.window().getArgs(window)\n        }\n        df = df.assign(**temporary_operand_columns)\n        temporary_operand_columns = 
list(temporary_operand_columns.keys())\n\n        operations.append(\n            (operation, new_temporary_column(df), temporary_operand_columns)\n        )\n\n        return operations, df\n"
  },
  {
    "path": "dask_sql/physical/rex/__init__.py",
    "content": "from .convert import RexConverter\n"
  },
  {
    "path": "dask_sql/physical/rex/base.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING, Any, Union\n\nimport dask.dataframe as dd\n\nimport dask_sql\nfrom dask_sql.datacontainer import DataContainer\n\nif TYPE_CHECKING:\n    from dask_sql._datafusion_lib import Expression, LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass BaseRexPlugin:\n    \"\"\"\n    Base class for all plugins to convert between\n    a RexNode to a python expression (dask dataframe column or raw value).\n\n    Derived classed needs to override the class_name attribute\n    and the convert method.\n    \"\"\"\n\n    class_name = None\n\n    def convert(\n        self,\n        rel: \"LogicalPlan\",\n        rex: \"Expression\",\n        dc: DataContainer,\n        context: \"dask_sql.Context\",\n    ) -> Union[dd.Series, Any]:\n        \"\"\"Base method to implement\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "dask_sql/physical/rex/convert.py",
    "content": "import logging\nfrom typing import TYPE_CHECKING, Any, Union\n\nimport dask.dataframe as dd\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rex.base import BaseRexPlugin\nfrom dask_sql.utils import LoggableDataFrame, Pluggable\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import Expression, LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n_REX_TYPE_TO_PLUGIN = {\n    \"RexType.Reference\": \"InputRef\",\n    \"RexType.Call\": \"RexCall\",\n    \"RexType.Literal\": \"RexLiteral\",\n    \"RexType.Alias\": \"RexAlias\",\n    \"RexType.ScalarSubquery\": \"ScalarSubquery\",\n}\n\n\nclass RexConverter(Pluggable):\n    \"\"\"\n    Helper to convert from rex to a python expression\n\n    This class stores plugins which can convert from RexNodes to\n    python expression (single values or dask dataframe columns).\n    The stored plugins are assumed to have a class attribute \"class_name\"\n    to control, which java classes they can convert\n    and they are expected to have a convert (instance) method\n    in the form\n\n        def convert(self, rex, df)\n\n    to do the actual conversion.\n    \"\"\"\n\n    @classmethod\n    def add_plugin_class(cls, plugin_class: BaseRexPlugin, replace=True):\n        \"\"\"Convenience function to add a class directly to the plugins\"\"\"\n        logger.debug(f\"Registering REX plugin for {plugin_class.class_name}\")\n        cls.add_plugin(plugin_class.class_name, plugin_class(), replace=replace)\n\n    @classmethod\n    def convert(\n        cls,\n        rel: \"LogicalPlan\",\n        rex: \"Expression\",\n        dc: DataContainer,\n        context: \"dask_sql.Context\",\n    ) -> Union[dd.DataFrame, Any]:\n        \"\"\"\n        Convert the given Expression\n        into a python expression (a dask dataframe)\n        using the stored plugins and the dictionary of\n        registered dask tables.\n        \"\"\"\n        expr_type = 
_REX_TYPE_TO_PLUGIN[str(rex.getRexType())]\n\n        try:\n            plugin_instance = cls.get_plugin(expr_type)\n        except KeyError:  # pragma: no cover\n            raise NotImplementedError(\n                f\"No conversion for class {expr_type} available (yet).\"\n            )\n\n        logger.debug(\n            f\"Processing REX {rex} using {plugin_instance.__class__.__name__}...\"\n        )\n\n        df = plugin_instance.convert(rel, rex, dc, context=context)\n        logger.debug(f\"Processed REX {rex} into {LoggableDataFrame(df)}\")\n        return df\n"
  },
  {
    "path": "dask_sql/physical/rex/core/__init__.py",
    "content": "from .alias import RexAliasPlugin\nfrom .call import RexCallPlugin\nfrom .input_ref import RexInputRefPlugin\nfrom .literal import RexLiteralPlugin\nfrom .subquery import RexScalarSubqueryPlugin\n\n__all__ = [\n    RexAliasPlugin,\n    RexCallPlugin,\n    RexInputRefPlugin,\n    RexLiteralPlugin,\n    RexScalarSubqueryPlugin,\n]\n"
  },
  {
    "path": "dask_sql/physical/rex/core/alias.py",
    "content": "from typing import TYPE_CHECKING, Any, Union\n\nimport dask.dataframe as dd\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rex import RexConverter\nfrom dask_sql.physical.rex.base import BaseRexPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import Expression, LogicalPlan\n\n\nclass RexAliasPlugin(BaseRexPlugin):\n    \"\"\"\n    A RexAliasPlugin is an expression, which references a Subquery.\n    This plugin is thin on logic, however keeping with previous patterns\n    we use the plugin approach instead of placing the logic inline\n    \"\"\"\n\n    class_name = \"RexAlias\"\n\n    def convert(\n        self,\n        rel: \"LogicalPlan\",\n        rex: \"Expression\",\n        dc: DataContainer,\n        context: \"dask_sql.Context\",\n    ) -> Union[dd.Series, Any]:\n        # extract the operands; there should only be a single underlying Expression\n        operands = rex.getOperands()\n        assert len(operands) == 1\n\n        sub_rex = operands[0]\n\n        value = RexConverter.convert(rel, sub_rex, dc, context=context)\n\n        if isinstance(value, DataContainer):\n            return value.df\n\n        return value\n"
  },
  {
    "path": "dask_sql/physical/rex/core/call.py",
    "content": "import logging\nimport operator\nimport re\nimport warnings\nfrom datetime import datetime\nfrom functools import partial, reduce\nfrom typing import TYPE_CHECKING, Any, Callable, Union\n\nimport dask.array as da\nimport dask.config as dask_config\nimport dask.dataframe as dd\nimport numpy as np\nimport pandas as pd\nfrom dask.utils import random_state_data\n\nfrom dask_sql._datafusion_lib import SqlTypeName\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.mappings import (\n    cast_column_to_type,\n    sql_to_python_type,\n    sql_to_python_value,\n)\nfrom dask_sql.physical.rel import RelConverter\nfrom dask_sql.physical.rex import RexConverter\nfrom dask_sql.physical.rex.base import BaseRexPlugin\nfrom dask_sql.physical.rex.core.literal import SargPythonImplementation\nfrom dask_sql.utils import (\n    LoggableDataFrame,\n    convert_to_datetime,\n    is_cudf_type,\n    is_datetime,\n    is_frame,\n)\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import Expression, LogicalPlan\n\nlogger = logging.getLogger(__name__)\nSeriesOrScalar = Union[dd.Series, Any]\n\n\ndef as_timelike(op):\n    if isinstance(op, np.int64):\n        return np.timedelta64(op, \"D\")\n    elif isinstance(op, str):\n        try:\n            return np.datetime64(op)\n        except ValueError:\n            op = datetime.strptime(op, \"%Y-%m-%d\")\n            return np.datetime64(op.strftime(\"%Y-%m-%d\"))\n    elif pd.api.types.is_datetime64_dtype(op) or isinstance(op, np.timedelta64):\n        return op\n    else:\n        raise ValueError(f\"Don't know how to make {type(op)} timelike\")\n\n\nclass Operation:\n    \"\"\"Helper wrapper around a function, which is used as operator\"\"\"\n\n    # True, if the operation should also get the dataframe passed\n    needs_dc = False\n\n    # True, if the operation should also get the REX\n    needs_rex = False\n\n    # True, if the operation should also needs the Context, possible 
subquery Relation expansion\n    needs_context = False\n\n    # True, if the operation needs the original relation algebra\n    needs_rel = False\n\n    @staticmethod\n    def op_needs_dc(op):\n        return hasattr(op, \"needs_dc\") and op.needs_dc\n\n    @staticmethod\n    def op_needs_rex(op):\n        return hasattr(op, \"needs_rex\") and op.needs_rex\n\n    @staticmethod\n    def op_needs_context(op):\n        return hasattr(op, \"needs_context\") and op.needs_context\n\n    @staticmethod\n    def op_needs_rel(op):\n        return hasattr(op, \"needs_rel\") and op.needs_rel\n\n    def __init__(self, f: Callable):\n        \"\"\"Init with the given function\"\"\"\n        self.f = f\n\n    def __call__(self, *operands, **kwargs) -> SeriesOrScalar:\n        \"\"\"Call the stored function\"\"\"\n        return self.f(*operands, **kwargs)\n\n    def of(self, op: \"Operation\") -> \"Operation\":\n        \"\"\"Functional composition\"\"\"\n        new_op = Operation(lambda *x, **kwargs: self(op(*x, **kwargs)))\n        new_op.needs_dc = Operation.op_needs_dc(op)\n        new_op.needs_rex = Operation.op_needs_rex(op)\n        new_op.needs_context = Operation.op_needs_context(op)\n        new_op.needs_rel = Operation.op_needs_rel(op)\n\n        return new_op\n\n\nclass PredicateBasedOperation(Operation):\n    \"\"\"\n    Helper operation to call a function on the input,\n    depending if the first arg evaluates, given a predicate function, to true or false\n    \"\"\"\n\n    def __init__(\n        self, predicate: Callable, true_route: Callable, false_route: Callable\n    ):\n        super().__init__(self.apply)\n        self.predicate = predicate\n        self.true_route = true_route\n        self.false_route = false_route\n\n    def apply(self, *operands, **kwargs):\n        if self.predicate(operands[0]):\n            return self.true_route(*operands, **kwargs)\n\n        return self.false_route(*operands, **kwargs)\n\n\nclass 
TensorScalarOperation(PredicateBasedOperation):\n    \"\"\"\n    Helper operation to call a function on the input,\n    depending if the first is a dataframe or not\n    \"\"\"\n\n    def __init__(self, tensor_f: Callable, scalar_f: Callable = None):\n        \"\"\"Init with the given operation\"\"\"\n        super().__init__(is_frame, tensor_f, scalar_f)\n\n\nclass ReduceOperation(Operation):\n    \"\"\"Special operator, which is executed by reducing an operation over the input\"\"\"\n\n    def __init__(self, operation: Callable, unary_operation: Callable = None):\n        self.operation = operation\n        self.unary_operation = unary_operation or operation\n        self.needs_dc = Operation.op_needs_dc(self.operation)\n        self.needs_rex = Operation.op_needs_rex(self.operation)\n\n        super().__init__(self.reduce)\n\n    def reduce(self, *operands, **kwargs):\n        if len(operands) > 1:\n            if any(\n                map(\n                    lambda op: is_frame(op) & pd.api.types.is_datetime64_dtype(op),\n                    operands,\n                )\n            ):\n                operands = tuple(map(as_timelike, operands))\n            return reduce(partial(self.operation, **kwargs), operands)\n        else:\n            return self.unary_operation(*operands, **kwargs)\n\n\nclass SQLDivisionOperator(Operation):\n    \"\"\"\n    Division is handled differently in SQL and python.\n    In python3, it will always preserve the full information, even if starting with\n    an integer (so 1/2 = 0.5).\n    In SQL, integer division will return an integer again. 
However, it is not floor division\n    (where -1/2 = -1), but truncated division (so -1 / 2 = 0).\n    \"\"\"\n\n    needs_rex = True\n\n    def __init__(self):\n        super().__init__(self.div)\n\n    def div(self, lhs, rhs, rex=None):\n        result = lhs / rhs\n\n        output_type = str(rex.getType())\n        output_type = sql_to_python_type(SqlTypeName.fromString(output_type.upper()))\n\n        is_float = pd.api.types.is_float_dtype(output_type)\n        if not is_float:\n            result = da.trunc(result)\n\n        return result\n\n\nclass IntDivisionOperator(Operation):\n    \"\"\"\n    Truncated integer division (so -1 / 2 = 0).\n    This is only used for internal calculations,\n    which are created by Calcite.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__(self.div)\n\n    def div(self, lhs, rhs):\n        result = lhs / rhs\n\n        # Specialized code for literals like \"1000µs\"\n        # For some reason, Calcite decides to represent\n        # 1000µs as 1000µs * 1000 / 1000\n        # We do not need to truncate in this case\n        # So far, I did not spot any other occurrence\n        # of this function.\n        if isinstance(result, np.timedelta64):\n            return result\n        else:\n            return da.trunc(result).astype(np.int64)\n\n\nclass CaseOperation(Operation):\n    \"\"\"The case operator (basically an if then else)\"\"\"\n\n    def __init__(self):\n        super().__init__(self.case)\n\n    def case(self, *operands) -> SeriesOrScalar:\n        \"\"\"\n        Returns `then` where `where`, else `other`.\n        \"\"\"\n        assert operands\n\n        where = operands[0]\n        then = operands[1]\n\n        if len(operands) > 3:\n            other = self.case(*operands[2:])\n        elif len(operands) == 2:\n            # CASE/WHEN statement without an ELSE\n            other = None\n        else:\n            other = operands[2]\n\n        if is_frame(then):\n            return 
then.where(where, other=other)\n        elif is_frame(other):\n            return other.where(~where, other=then)\n        elif is_frame(where):\n            # This one is a bit tricky.\n            # Everything except \"where\" are scalars.\n            # To make the \"df.where\" function still usable\n            # we create a temporary dataframe with the\n            # properties of where (but the content of then).\n            tmp = where.apply(lambda x: then, meta=(where.name, type(then)))\n            return tmp.where(where, other=other)\n        else:\n            return then if where else other\n\n\nclass CastOperation(Operation):\n    \"\"\"The cast operator\"\"\"\n\n    needs_rex = True\n\n    def __init__(self):\n        super().__init__(self.cast)\n\n    def cast(self, operand, rex=None) -> SeriesOrScalar:\n        output_type = rex.getType()\n        sql_type = SqlTypeName.fromString(output_type)\n        sql_type_args = ()\n\n        # decimal datatypes require precision and scale\n        if output_type == \"DECIMAL\":\n            sql_type_args = rex.getPrecisionScale()\n\n        if output_type == \"TIMESTAMP\" and pd.api.types.is_integer_dtype(operand):\n            operand = operand * 10**9\n\n        if not is_frame(operand):  # pragma: no cover\n            return sql_to_python_value(sql_type, operand)\n\n        python_type = sql_to_python_type(sql_type, *sql_type_args)\n\n        return_column = cast_column_to_type(operand, python_type)\n\n        if return_column is None:\n            return_column = operand\n\n        # TODO: ideally we don't want to directly access the datetimes,\n        # but Pandas can't truncate timezone datetimes and cuDF can't\n        # truncate datetimes\n        if output_type == \"DATE\":\n            return return_column.dt.floor(\"D\").astype(python_type)\n\n        return return_column\n\n\nclass IsFalseOperation(Operation):\n    \"\"\"The is false operator\"\"\"\n\n    def __init__(self):\n        
super().__init__(self.false_)\n\n    def false_(\n        self,\n        df: SeriesOrScalar,\n    ) -> SeriesOrScalar:\n        \"\"\"\n        Returns true where `df` is false (where `df` can also be just a scalar).\n        Returns false on nan.\n        \"\"\"\n        if is_frame(df):\n            return ~df.astype(\"boolean\").fillna(True)\n\n        return not pd.isna(df) and df is not None and not np.isnan(df) and not bool(df)\n\n\nclass IsTrueOperation(Operation):\n    \"\"\"The is true operator\"\"\"\n\n    def __init__(self):\n        super().__init__(self.true_)\n\n    def true_(\n        self,\n        df: SeriesOrScalar,\n    ) -> SeriesOrScalar:\n        \"\"\"\n        Returns true where `df` is true (where `df` can also be just a scalar).\n        Returns false on nan.\n        \"\"\"\n        if is_frame(df):\n            return df.astype(\"boolean\").fillna(False)\n\n        return not pd.isna(df) and df is not None and not np.isnan(df) and bool(df)\n\n\nclass NegativeOperation(Operation):\n    \"\"\"The negative operator\"\"\"\n\n    def __init__(self):\n        super().__init__(self.negative_)\n\n    def negative_(\n        self,\n        df: SeriesOrScalar,\n    ) -> SeriesOrScalar:\n        return -df\n\n\nclass NotOperation(Operation):\n    \"\"\"The not operator\"\"\"\n\n    def __init__(self):\n        super().__init__(self.not_)\n\n    def not_(\n        self,\n        df: SeriesOrScalar,\n    ) -> SeriesOrScalar:\n        \"\"\"\n        Returns not `df` (where `df` can also be just a scalar).\n        \"\"\"\n        if is_frame(df):\n            return ~(df.astype(\"boolean\"))\n        else:\n            return not df\n\n\nclass IsNullOperation(Operation):\n    \"\"\"The is null operator\"\"\"\n\n    def __init__(self):\n        super().__init__(self.null)\n\n    def null(\n        self,\n        df: SeriesOrScalar,\n    ) -> SeriesOrScalar:\n        \"\"\"\n        Returns true where `df` is null (where `df` can also be just a 
scalar).\n        \"\"\"\n        if is_frame(df):\n            return df.isna()\n\n        return pd.isna(df) or df is None or np.isnan(df)\n\n\nclass IsNotDistinctOperation(Operation):\n    \"\"\"The is not distinct operator\"\"\"\n\n    def __init__(self):\n        super().__init__(self.not_distinct)\n\n    def not_distinct(self, lhs: SeriesOrScalar, rhs: SeriesOrScalar) -> SeriesOrScalar:\n        \"\"\"\n        Returns true where `lhs` is not distinct from `rhs` (or both are null).\n        \"\"\"\n        is_null = IsNullOperation()\n\n        return (is_null(lhs) & is_null(rhs)) | (lhs == rhs)\n\n\nclass RegexOperation(Operation):\n    \"\"\"An abstract regex operation, which transforms the SQL regex into something python can understand\"\"\"\n\n    needs_rex = True\n\n    def __init__(self):\n        super().__init__(self.regex)\n\n    def regex(self, test: SeriesOrScalar, regex: str, rex=None) -> SeriesOrScalar:\n        \"\"\"\n        Returns true, if the string test matches the given regex\n        (maybe escaped by escape)\n        \"\"\"\n        escape = rex.getEscapeChar() if rex else None\n        if not escape:\n            escape = \"\\\\\"\n\n        # Unfortunately, SQL's like syntax is not directly\n        # a regular expression. We need to do some translation\n        # SQL knows about the following wildcards:\n        # %, ?, [], _, #\n        transformed_regex = \"\"\n        escaped = False\n        in_char_range = False\n        for char in regex:\n            # Escape characters with \"\\\"\n            if escaped:\n                char = \"\\\\\" + char\n                escaped = False\n\n            # Keep character ranges [...] 
as they are\n            elif in_char_range:\n                if char == \"]\":\n                    in_char_range = False\n\n            # These chars have a special meaning in regex\n            # whereas in SQL they have not, so we need to\n            # add additional escaping\n            elif char in self.replacement_chars:\n                char = \"\\\\\" + char\n\n            elif char == \"[\":\n                in_char_range = True\n\n            # The needed \"\\\" is printed above, so we continue\n            elif char == escape:\n                escaped = True\n                continue\n\n            # An unescaped \"%\" in SQL is a .*\n            elif char == \"%\":\n                char = \".*\"\n\n            # An unescaped \"_\" in SQL is a .\n            elif char == \"_\":\n                char = \".\"\n\n            transformed_regex += char\n\n        # the SQL like always goes over the full string\n        transformed_regex = \"^\" + transformed_regex + \"$\"\n\n        # Finally, apply the string\n        flags = re.DOTALL | re.IGNORECASE if not self.case_sensitive else re.DOTALL\n        if is_frame(test):\n            return test.str.match(transformed_regex, flags=flags).astype(\"boolean\")\n        else:\n            return bool(re.match(transformed_regex, test, flags=flags))\n\n\nclass LikeOperation(RegexOperation):\n    def __init__(self, case_sensitive: bool = True):\n        self.case_sensitive = case_sensitive\n        self.replacement_chars = [\n            \"#\",\n            \"$\",\n            \"^\",\n            \".\",\n            \"|\",\n            \"~\",\n            \"-\",\n            \"+\",\n            \"*\",\n            \"?\",\n            \"(\",\n            \")\",\n            \"{\",\n            \"}\",\n            \"[\",\n            \"]\",\n        ]\n        super().__init__()\n\n\nclass SimilarOperation(RegexOperation):\n    replacement_chars = [\n        \"#\",\n        \"$\",\n        \"^\",\n        \".\",\n   
     \"~\",\n        \"-\",\n    ]\n    case_sensitive = True\n\n\nclass PositionOperation(Operation):\n    \"\"\"The position operator (get the position of a string)\"\"\"\n\n    def __init__(self):\n        super().__init__(self.position)\n\n    def position(self, search, s, start=None):\n        \"\"\"Attention: SQL starts counting at 1\"\"\"\n        if is_frame(s):\n            s = s.str\n\n        if start is None or start <= 0:\n            start = 0\n        else:\n            start -= 1\n\n        return s.find(search, start) + 1\n\n\nclass SubStringOperation(Operation):\n    \"\"\"The substring operator (get a slice of a string)\"\"\"\n\n    def __init__(self):\n        super().__init__(self.substring)\n\n    def substring(self, s, start, length=None):\n        \"\"\"Attention: SQL starts counting at 1\"\"\"\n        if start <= 0:\n            start = 0\n        else:\n            start -= 1\n\n        end = length + start if length else None\n        if is_frame(s):\n            return s.str.slice(start, end)\n\n        if end:\n            return s[start:end]\n        else:\n            return s[start:]\n\n\nclass TrimOperation(Operation):\n    \"\"\"The trim operator (remove occurrences left and right of a string)\"\"\"\n\n    def __init__(self, flag=\"BOTH\"):\n        self.flag = flag\n        super().__init__(self.trim)\n\n    def trim(self, s, search):\n        if is_frame(s):\n            s = s.str\n\n        if self.flag == \"LEADING\":\n            strip_call = s.lstrip\n        elif self.flag == \"TRAILING\":\n            strip_call = s.rstrip\n        elif self.flag == \"BOTH\":\n            strip_call = s.strip\n        else:\n            raise ValueError(f\"Trim type {self.flag} not recognized\")\n\n        return strip_call(search)\n\n\nclass ReplaceOperation(Operation):\n    \"\"\"The replace operator (replace occurrences of pattern in a string)\"\"\"\n\n    def __init__(self):\n        super().__init__(self.replace)\n\n    def 
replace(self, s, pat, repl):\n        if is_frame(s):\n            s = s.str\n\n        return s.replace(pat, repl)\n\n\nclass OverlayOperation(Operation):\n    \"\"\"The overlay operator (replace string according to positions)\"\"\"\n\n    def __init__(self):\n        super().__init__(self.overlay)\n\n    def overlay(self, s, replace, start, length=None):\n        \"\"\"Attention: SQL starts counting at 1\"\"\"\n        if start <= 0:\n            start = 0\n        else:\n            start -= 1\n\n        if length is None:\n            length = len(replace)\n        end = length + start\n\n        if is_frame(s):\n            return s.str.slice_replace(start, end, replace)\n\n        s = s[:start] + replace + s[end:]\n        return s\n\n\nclass CoalesceOperation(Operation):\n    def __init__(self):\n        super().__init__(self.coalesce)\n\n    def coalesce(self, *operands):\n        result = None\n        for operand in operands:\n            if is_frame(operand):\n                # Check if frame evaluates to nan or NA\n                if len(operand) == 1 and not operand.isnull().all().compute():\n                    return operand if result is None else result.fillna(operand)\n                else:\n                    result = operand if result is None else result.fillna(operand)\n            elif not pd.isna(operand):\n                return operand if result is None else result.fillna(operand)\n\n        return result\n\n\nclass ToTimestampOperation(Operation):\n    def __init__(self):\n        super().__init__(self.to_timestamp)\n\n    def to_timestamp(self, df, format):\n        default_format = \"%Y-%m-%d %H:%M:%S\"\n        # Remove double and single quotes from string\n        format = format.replace('\"', \"\")\n        format = format.replace(\"'\", \"\")\n\n        # String cases\n        if type(df) == str:\n            return np.datetime64(datetime.strptime(df, format))\n        elif df.dtype == \"object\":\n            return 
dd.to_datetime(df, format=format)\n        # Integer cases\n        elif np.isscalar(df):\n            if format != default_format:\n                raise RuntimeError(\"Integer input does not accept a format argument\")\n            return np.datetime64(int(df), \"s\")\n        else:\n            if format != default_format:\n                raise RuntimeError(\"Integer input does not accept a format argument\")\n            return dd.to_datetime(df, unit=\"s\")\n\n\nclass YearOperation(Operation):\n    def __init__(self):\n        super().__init__(self.extract_year)\n\n    def extract_year(self, df: SeriesOrScalar):\n        df = convert_to_datetime(df)\n        return df.year\n\n\nclass TimeStampAddOperation(Operation):\n    def __init__(self):\n        super().__init__(self.timestampadd)\n\n    def timestampadd(self, unit, interval, df: SeriesOrScalar):\n        unit = unit.upper()\n        interval = int(interval)\n        if interval < 0:\n            raise RuntimeError(f\"Negative time interval {interval} is not supported.\")\n        df = (\n            df.astype(\"datetime64[s]\")\n            if pd.api.types.is_integer_dtype(df)\n            else df.astype(\"datetime64[ns]\")\n        )\n\n        if is_cudf_type(df):\n            from cudf import DateOffset\n        else:\n            from pandas.tseries.offsets import DateOffset\n\n        if unit in {\"YEAR\", \"YEARS\"}:\n            return df + DateOffset(years=interval)\n        elif unit in {\"QUARTER\", \"QUARTERS\", \"MONTH\", \"MONTHS\"}:\n            if unit in {\"QUARTER\", \"QUARTERS\"}:\n                return df + DateOffset(months=interval * 3)\n            else:  # \"MONTH\"\n                return df + DateOffset(months=interval)\n        elif unit in {\"WEEK\", \"WEEKS\", \"SQL_TSI_WEEK\"}:\n            return df + DateOffset(weeks=interval)\n        elif unit in {\"DAY\", \"DAYS\", \"SQL_TSI_DAY\"}:\n            return df + DateOffset(days=interval)\n        elif unit in {\"HOUR\", 
\"HOURS\", \"SQL_TSI_HOUR\"}:\n            return df + DateOffset(hours=interval)\n        elif unit in {\"MINUTE\", \"MINUTES\", \"SQL_TSI_MINUTE\"}:\n            return df + DateOffset(minutes=interval)\n        elif unit in {\"SECOND\", \"SECONDS\", \"SQL_TSI_SECOND\"}:\n            return df + DateOffset(seconds=interval)\n        elif unit in {\"MILLISECOND\", \"MILLISECONDS\"}:\n            return df + DateOffset(milliseconds=interval)\n        elif unit in {\"MICROSECOND\", \"MICROSECONDS\"}:\n            return df + DateOffset(microseconds=interval)\n        else:\n            raise NotImplementedError(\n                f\"Timestamp addition with {unit} is not supported.\"\n            )\n\n\nclass DatetimeSubOperation(Operation):\n    \"\"\"\n    Datetime subtraction is a special case of the `minus` operation\n    which also specifies a sql interval return type for the operation.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__(self.datetime_sub)\n\n    def datetime_sub(self, unit, df1, df2):\n        if pd.api.types.is_integer_dtype(df1):\n            df1 = df1 * 10**9\n        if pd.api.types.is_integer_dtype(df2):\n            df2 = df2 * 10**9\n        if \"datetime64[s]\" == str(getattr(df1, \"dtype\", \"\")):\n            df1 = df1.astype(\"datetime64[ns]\")\n        if \"datetime64[s]\" == str(getattr(df2, \"dtype\", \"\")):\n            df2 = df2.astype(\"datetime64[ns]\")\n\n        subtraction_op = ReduceOperation(\n            operation=operator.sub, unary_operation=lambda x: -x\n        )\n        result = subtraction_op(df2, df1)\n\n        if is_cudf_type(df1):\n            result = result.astype(\"int\")\n\n        if unit in {\"NANOSECOND\", \"NANOSECONDS\"}:\n            return result\n        elif unit in {\"MICROSECOND\", \"MICROSECONDS\"}:\n            return result // 1_000\n        elif unit in {\"SECOND\", \"SECONDS\"}:\n            return result // 1_000_000_000\n        elif unit in {\"MINUTE\", \"MINUTES\"}:\n      
      return (result / 1_000_000_000) // 60\n        elif unit in {\"HOUR\", \"HOURS\"}:\n            return (result / 1_000_000_000) // 3600\n        elif unit in {\"DAY\", \"DAYS\"}:\n            return ((result / 1_000_000_000) / 3600) // 24\n        elif unit in {\"WEEK\", \"WEEKS\"}:\n            return (((result / 1_000_000_000) / 3600) / 24) // 7\n        elif unit in {\"MONTH\", \"MONTHS\"}:\n            day_result = ((result / 1_000_000_000) / 3600) // 24\n            avg_days_in_month = ((30 * 4) + 28 + (31 * 7)) / 12\n            return day_result / avg_days_in_month\n        elif unit in {\"QUARTER\", \"QUARTERS\"}:\n            day_result = ((result / 1_000_000_000) / 3600) // 24\n            avg_days_in_quarter = 3 * ((30 * 4) + 28 + (31 * 7)) / 12\n            return day_result / avg_days_in_quarter\n        elif unit in {\"YEAR\", \"YEARS\"}:\n            return (((result / 1_000_000_000) / 3600) / 24) // 365\n        else:\n            raise NotImplementedError(\n                f\"Timestamp difference with {unit} is not supported.\"\n            )\n\n\nclass CeilFloorOperation(PredicateBasedOperation):\n    \"\"\"\n    Apply ceil/floor operations on a series depending on its dtype (datetime like vs normal)\n    \"\"\"\n\n    def __init__(self, round_method: str):\n        assert round_method in {\n            \"ceil\",\n            \"floor\",\n        }, \"Round method can only be either ceil or floor\"\n\n        super().__init__(\n            is_datetime,  # if the series is dt type\n            self._round_datetime,\n            getattr(da, round_method),\n        )\n\n        self.round_method = round_method\n\n    def _round_datetime(self, *operands):\n        df, unit = operands\n\n        df = convert_to_datetime(df)\n\n        unit_map = {\n            \"DAY\": \"D\",\n            \"HOUR\": \"h\",\n            \"MINUTE\": \"min\",\n            \"SECOND\": \"s\",\n            \"MICROSECOND\": \"U\",\n            \"MILLISECOND\": \"ms\",\n   
     }\n\n        try:\n            freq = unit_map[unit.upper()]\n            return getattr(df, self.round_method)(freq)\n        except KeyError:\n            raise NotImplementedError(\n                f\"{self.round_method} TO {unit} is not (yet) implemented.\"\n            )\n\n\nclass BaseRandomOperation(Operation):\n    \"\"\"\n    Return a random number (specified by the given function) with the random number\n    generator set to the given seed.\n    As we need to know how many random numbers we should generate,\n    we also get the current dataframe as input and use it to\n    create random numbers for each partition separately.\n    To make this deterministic, we use the partition number\n    as additional input to the seed.\n    \"\"\"\n\n    needs_dc = True\n\n    def random_function(self, partition, random_state, kwargs):\n        \"\"\"Needs to be implemented in derived classes\"\"\"\n        raise NotImplementedError\n\n    def random_frame(self, seed: int, dc: DataContainer, **kwargs) -> dd.Series:\n        \"\"\"This function - in contrast to others in this module - will only ever be called on data frames\"\"\"\n        df = dc.df\n        state_data = random_state_data(df.npartitions, np.random.RandomState(seed=seed))\n\n        def random_partition_func(df, state_data, partition_info=None):\n            \"\"\"Create a random number for each partition\"\"\"\n            partition_index = (\n                partition_info[\"number\"] if partition_info is not None else 0\n            )\n\n            state = np.random.RandomState(state_data[partition_index])\n            return self.random_function(df, state, kwargs)\n\n        random_series = df.map_partitions(\n            random_partition_func, state_data, meta=(\"random\", \"float64\")\n        )\n\n        # This part seems to be stupid, but helps us do a very simple\n        # task without going into the (private) internals of Dask:\n        # copy all meta information from the original 
input dataframe\n        # This is important so that the returned series looks\n        # exactly as if it came from the input dataframe\n        return df.assign(random=random_series)[\"random\"]\n\n\nclass RandOperation(BaseRandomOperation):\n    \"\"\"Create a random number between 0 and 1\"\"\"\n\n    def __init__(self):\n        super().__init__(f=self.rand)\n\n    def rand(self, seed: int = None, dc: DataContainer = None):\n        return self.random_frame(seed=seed, dc=dc)\n\n    def random_function(self, partition, random_state, kwargs):\n        return random_state.random_sample(size=len(partition))\n\n\nclass RandIntegerOperation(BaseRandomOperation):\n    \"\"\"Create a random integer between 0 and high\"\"\"\n\n    def __init__(self):\n        super().__init__(f=self.rand_integer)\n\n    def rand_integer(\n        self, seed: int = None, high: int = None, dc: DataContainer = None\n    ):\n        # Two possibilities: RAND_INTEGER(seed, high) or RAND_INTEGER(high)\n        if high is None:\n            high = seed\n            seed = None\n        return self.random_frame(seed=seed, high=high, dc=dc)\n\n    def random_function(self, partition, random_state, kwargs):\n        return random_state.randint(size=len(partition), low=0, **kwargs)\n\n\nclass SearchOperation(Operation):\n    \"\"\"\n    Search is a special operation in SQL, which allows writing \"range-like\"\n    conditions, such as\n\n        (1 < a AND a < 2) OR (4 < a AND a < 6)\n\n    in a more convenient form.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__(self.search)\n\n    def search(self, series: dd.Series, sarg: SargPythonImplementation):\n        conditions = [r.filter_on(series) for r in sarg.ranges]\n\n        assert len(conditions) > 0\n\n        if len(conditions) > 1:\n            or_operation = ReduceOperation(operation=operator.or_)\n            return or_operation(*conditions)\n        else:\n            return conditions[0]\n\n\nclass 
ExtractOperation(Operation):\n    \"\"\"\n    Extract a date/time field from a datetime scalar or series,\n    similar to PostgreSQL's date_part function.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__(self.date_part)\n\n    def date_part(self, what, df: SeriesOrScalar):\n        what = what.upper()\n        df = convert_to_datetime(df)\n\n        if what in {\"YEAR\", \"YEARS\"}:\n            return df.year\n        elif what in {\"CENTURY\", \"CENTURIES\"}:\n            return da.trunc(df.year / 100)\n        elif what in {\"DAY\", \"DAYS\"}:\n            return df.day\n        elif what in {\"DECADE\", \"DECADES\"}:\n            return da.trunc(df.year / 10)\n        elif what == \"DOW\":\n            return (df.dayofweek + 1) % 7\n        elif what == \"DOY\":\n            return df.dayofyear\n        elif what in {\"HOUR\", \"HOURS\"}:\n            return df.hour\n        elif what in {\"MICROSECOND\", \"MICROSECONDS\"}:\n            return df.microsecond\n        elif what in {\"MILLENIUM\", \"MILLENIUMS\", \"MILLENNIUM\", \"MILLENNIUMS\"}:\n            return da.trunc(df.year / 1000)\n        elif what in {\"MILLISECOND\", \"MILLISECONDS\"}:\n            return da.trunc(1000 * df.microsecond)\n        elif what in {\"MINUTE\", \"MINUTES\"}:\n            return df.minute\n        elif what in {\"MONTH\", \"MONTHS\"}:\n            return df.month\n        elif what in {\"QUARTER\", \"QUARTERS\"}:\n            return df.quarter\n        elif what in {\"SECOND\", \"SECONDS\"}:\n            return df.second\n        elif what in {\"WEEK\", \"WEEKS\"}:\n            return df.isocalendar().week\n        elif what == \"DATE\":\n            return (\n                df.date()\n                if isinstance(df, pd.Timestamp)\n                else dd.to_datetime(df.strftime(\"%Y-%m-%d\"))\n            )\n        else:\n            raise NotImplementedError(f\"Extraction of {what} is not (yet) 
implemented.\")\n\n\nclass BetweenOperation(Operation):\n    \"\"\"\n    Function for finding rows between two scalar values\n    \"\"\"\n\n    needs_rex = True\n\n    def __init__(self):\n        super().__init__(self.between)\n\n    def between(self, series: dd.Series, low, high, rex=None):\n        return (\n            ~series.between(low, high, inclusive=\"both\")\n            if rex.isNegated()\n            else series.between(low, high, inclusive=\"both\")\n        )\n\n\nclass InListOperation(Operation):\n    \"\"\"\n    Returns a boolean of whether an expression is/isn't in a set of values\n    \"\"\"\n\n    needs_rex = True\n\n    def __init__(self):\n        super().__init__(self.inList)\n\n    def inList(self, series: dd.Series, *operands, rex=None):\n        result = series.isin(operands)\n        return ~result if rex.isNegated() else result\n\n\nclass InSubqueryOperation(Operation):\n    \"\"\"\n    Returns a boolean of whether an expression is/isn't in a Subquery Expression result\n    \"\"\"\n\n    needs_rex = True\n    needs_context = True\n    needs_rel = True\n\n    def __init__(self):\n        super().__init__(self.inSubquery)\n\n    def inSubquery(\n        self, series: dd.Series, *operands, rel=None, rex=None, context=None\n    ):\n        sub_rel = rex.getSubqueryLogicalPlan()\n        dc = RelConverter.convert(sub_rel, context=context)\n\n        # Extract the specified column/Series from the Dataframe\n        fq_column_name = rex.column_name(rel).split(\".\")\n\n        # FIXME: dask's isin doesn't support dask frames as arguments\n        # so we need to compute here\n        col = dc.df[fq_column_name[-1]].compute()\n\n        warnings.warn(\n            \"Dask doesn't support Dask frames as input for .isin, so we must force an early computation\",\n            ResourceWarning,\n        )\n\n        return series.isin(col)\n\n\nclass RexCallPlugin(BaseRexPlugin):\n    \"\"\"\n    RexCall is used for expressions, which calculate 
something.\n    An example is\n\n        SELECT a + b FROM ...\n\n    but also\n\n        a > 3\n\n    Typically, a RexCall has inputs (which can be RexNodes again)\n    and calls a function on these inputs.\n    The inputs can either be a column or a scalar value.\n    \"\"\"\n\n    class_name = \"RexCall\"\n\n    OPERATION_MAPPING = {\n        # \"binary\" functions\n        \"between\": BetweenOperation(),\n        \"and\": ReduceOperation(operation=operator.and_),\n        \"or\": ReduceOperation(operation=operator.or_),\n        \">\": ReduceOperation(operation=operator.gt),\n        \">=\": ReduceOperation(operation=operator.ge),\n        \"<\": ReduceOperation(operation=operator.lt),\n        \"<=\": ReduceOperation(operation=operator.le),\n        \"=\": ReduceOperation(operation=operator.eq),\n        \"!=\": ReduceOperation(operation=operator.ne),\n        \"<>\": ReduceOperation(operation=operator.ne),\n        \"+\": ReduceOperation(operation=operator.add, unary_operation=lambda x: x),\n        \"-\": ReduceOperation(operation=operator.sub, unary_operation=lambda x: -x),\n        \"/\": ReduceOperation(operation=SQLDivisionOperator()),\n        \"*\": ReduceOperation(operation=operator.mul),\n        \"is distinct from\": NotOperation().of(IsNotDistinctOperation()),\n        \"is not distinct from\": IsNotDistinctOperation(),\n        \"/int\": IntDivisionOperator(),\n        # special operations\n        \"cast\": CastOperation(),\n        \"case\": CaseOperation(),\n        \"not like\": NotOperation().of(LikeOperation(case_sensitive=True)),\n        \"like\": LikeOperation(case_sensitive=True),\n        \"not ilike\": NotOperation().of(LikeOperation(case_sensitive=False)),\n        \"ilike\": LikeOperation(case_sensitive=False),\n        \"not similar to\": NotOperation().of(SimilarOperation()),\n        \"similar to\": SimilarOperation(),\n        \"negative\": NegativeOperation(),\n        \"not\": NotOperation(),\n        \"in list\": 
InListOperation(),\n        \"in subquery\": InSubqueryOperation(),\n        \"is null\": IsNullOperation(),\n        \"is not null\": NotOperation().of(IsNullOperation()),\n        \"is true\": IsTrueOperation(),\n        \"is not true\": NotOperation().of(IsTrueOperation()),\n        \"is false\": IsFalseOperation(),\n        \"is not false\": NotOperation().of(IsFalseOperation()),\n        \"is unknown\": IsNullOperation(),\n        \"is not unknown\": NotOperation().of(IsNullOperation()),\n        \"rand\": RandOperation(),\n        \"random\": RandOperation(),\n        \"rand_integer\": RandIntegerOperation(),\n        \"search\": SearchOperation(),\n        # Unary math functions\n        \"abs\": TensorScalarOperation(lambda x: x.abs(), np.abs),\n        \"acos\": Operation(da.arccos),\n        \"asin\": Operation(da.arcsin),\n        \"atan\": Operation(da.arctan),\n        \"atan2\": Operation(da.arctan2),\n        \"cbrt\": Operation(da.cbrt),\n        \"ceil\": CeilFloorOperation(\"ceil\"),\n        \"cos\": Operation(da.cos),\n        \"cot\": Operation(lambda x: 1 / da.tan(x)),\n        \"degrees\": Operation(da.degrees),\n        \"exp\": Operation(da.exp),\n        \"floor\": CeilFloorOperation(\"floor\"),\n        \"log10\": Operation(da.log10),\n        \"ln\": Operation(da.log),\n        \"mod\": Operation(da.mod),\n        \"power\": Operation(da.power),\n        \"radians\": Operation(da.radians),\n        \"round\": TensorScalarOperation(lambda x, *ops: x.round(*ops), np.round),\n        \"sign\": Operation(da.sign),\n        \"sin\": Operation(da.sin),\n        \"tan\": Operation(da.tan),\n        \"truncate\": Operation(da.trunc),\n        # string operations\n        \"||\": ReduceOperation(operation=operator.add),\n        \"concat\": ReduceOperation(operation=operator.add),\n        \"characterlength\": TensorScalarOperation(\n            lambda x: x.str.len(), lambda x: len(x)\n        ),\n        \"character_length\": 
TensorScalarOperation(\n            lambda x: x.str.len(), lambda x: len(x)\n        ),\n        \"upper\": TensorScalarOperation(lambda x: x.str.upper(), lambda x: x.upper()),\n        \"lower\": TensorScalarOperation(lambda x: x.str.lower(), lambda x: x.lower()),\n        \"position\": PositionOperation(),\n        \"trim\": TrimOperation(),\n        \"ltrim\": TrimOperation(\"LEADING\"),\n        \"rtrim\": TrimOperation(\"TRAILING\"),\n        \"btrim\": TrimOperation(\"BOTH\"),\n        \"overlay\": OverlayOperation(),\n        \"substr\": SubStringOperation(),\n        \"substring\": SubStringOperation(),\n        \"initcap\": TensorScalarOperation(lambda x: x.str.title(), lambda x: x.title()),\n        \"coalesce\": CoalesceOperation(),\n        \"replace\": ReplaceOperation(),\n        # date/time operations\n        \"extract_date\": ExtractOperation(),\n        \"localtime\": Operation(lambda *args: pd.Timestamp.now()),\n        \"localtimestamp\": Operation(lambda *args: pd.Timestamp.now()),\n        \"current_time\": Operation(lambda *args: pd.Timestamp.now()),\n        \"current_date\": Operation(lambda *args: pd.Timestamp.now()),\n        \"current_timestamp\": Operation(lambda *args: pd.Timestamp.now()),\n        \"last_day\": TensorScalarOperation(\n            lambda x: x + pd.tseries.offsets.MonthEnd(1),\n            lambda x: convert_to_datetime(x) + pd.tseries.offsets.MonthEnd(1),\n        ),\n        \"dsql_totimestamp\": ToTimestampOperation(),\n        # Temporary UDF functions that need to be moved after this POC\n        \"datepart\": ExtractOperation(),\n        \"date_part\": ExtractOperation(),\n        \"year\": YearOperation(),\n        \"timestampadd\": TimeStampAddOperation(),\n        \"timestampceil\": CeilFloorOperation(\"ceil\"),\n        \"timestampfloor\": CeilFloorOperation(\"floor\"),\n        \"timestampdiff\": DatetimeSubOperation(),\n    }\n\n    def convert(\n        self,\n        rel: \"LogicalPlan\",\n        expr: 
\"Expression\",\n        dc: DataContainer,\n        context: \"dask_sql.Context\",\n    ) -> SeriesOrScalar:\n\n        # Prepare the operands by turning the RexNodes into python expressions\n        operands = [\n            RexConverter.convert(rel, o, dc, context=context)\n            for o in expr.getOperands()\n        ]\n\n        # FIXME: cuDF doesn't support binops between decimal columns and numpy ints / floats\n        if dask_config.get(\"sql.mappings.decimal_support\") == \"cudf\" and any(\n            str(getattr(o, \"dtype\", None)) == \"decimal128\" for o in operands\n        ):\n            from decimal import Decimal\n\n            operands = [\n                Decimal(str(o))\n                if isinstance(o, float)\n                else o.item()\n                if np.isscalar(o) and pd.api.types.is_integer_dtype(o)\n                else o\n                for o in operands\n            ]\n\n        # Now use the operator name in the mapping\n        schema_name = context.schema_name\n        operator_name = expr.getOperatorName().lower()\n\n        try:\n            operation = self.OPERATION_MAPPING[operator_name]\n        except KeyError:\n            try:\n                operation = context.schema[schema_name].functions[operator_name]\n            except KeyError:  # pragma: no cover\n                raise NotImplementedError(\n                    f\"RexCall operator '{operator_name}' not (yet) implemented\"\n                )\n\n        logger.debug(\n            f\"Executing {operator_name} on {[str(LoggableDataFrame(df)) for df in operands]}\"\n        )\n\n        kwargs = {}\n\n        if Operation.op_needs_dc(operation):\n            kwargs[\"dc\"] = dc\n        if Operation.op_needs_rex(operation):\n            kwargs[\"rex\"] = expr\n        if Operation.op_needs_context(operation):\n            kwargs[\"context\"] = context\n        if Operation.op_needs_rel(operation):\n            kwargs[\"rel\"] = rel\n\n        return 
operation(*operands, **kwargs)\n        # TODO: We have information on the typing here - we should use it\n"
  },
  {
    "path": "dask_sql/physical/rex/core/input_ref.py",
    "content": "from typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rex.base import BaseRexPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import Expression, LogicalPlan\n\n\nclass RexInputRefPlugin(BaseRexPlugin):\n    \"\"\"\n    A RexInputRef is an expression which references a single column.\n    It is typically found in any expression which\n    calculates a function on a column of a table.\n    \"\"\"\n\n    class_name = \"InputRef\"\n\n    def convert(\n        self,\n        rel: \"LogicalPlan\",\n        rex: \"Expression\",\n        dc: DataContainer,\n        context: \"dask_sql.Context\",\n    ) -> dd.Series:\n        df = dc.df\n        cc = dc.column_container\n\n        # The column is referenced by index\n        index = rex.getIndex()\n        backend_column_name = cc.get_backend_by_frontend_index(index)\n        return df[backend_column_name]\n"
  },
  {
    "path": "dask_sql/physical/rex/core/literal.py",
    "content": "import logging\nfrom datetime import datetime\nfrom typing import TYPE_CHECKING, Any\n\nimport dask.dataframe as dd\nimport numpy as np\n\nfrom dask_sql._datafusion_lib import SqlTypeName\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.mappings import sql_to_python_value\nfrom dask_sql.physical.rex.base import BaseRexPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import Expression, LogicalPlan\n\nlogger = logging.getLogger(__name__)\n\n\nclass SargPythonImplementation:\n    \"\"\"\n    Apache Calcite comes with a Sarg literal, which stands for the\n    \"search arguments\" (which are later used in a SEARCH call).\n    We transform it into a more manageable python object\n    by extracting the Java properties.\n    \"\"\"\n\n    class Range:\n        \"\"\"Helper class to represent one of the ranges in a Sarg object\"\"\"\n\n        # def __init__(self, range: com.google.common.collect.Range, literal_type: str):\n        #     self.lower_endpoint = None\n        #     self.lower_open = True\n        #     if range.hasLowerBound():\n        #         self.lower_endpoint = sql_to_python_value(\n        #             literal_type, range.lowerEndpoint()\n        #         )\n        #         self.lower_open = (\n        #             range.lowerBoundType() == com.google.common.collect.BoundType.OPEN\n        #         )\n\n        #     self.upper_endpoint = None\n        #     self.upper_open = True\n        #     if range.hasUpperBound():\n        #         self.upper_endpoint = sql_to_python_value(\n        #             literal_type, range.upperEndpoint()\n        #         )\n        #         self.upper_open = (\n        #             range.upperBoundType() == com.google.common.collect.BoundType.OPEN\n        #         )\n\n        def filter_on(self, series: dd.Series):\n            lower_condition = True\n            if self.lower_endpoint is not None:\n                if self.lower_open:\n   
                 lower_condition = self.lower_endpoint < series\n                else:\n                    lower_condition = self.lower_endpoint <= series\n\n            upper_condition = True\n            if self.upper_endpoint is not None:\n                if self.upper_open:\n                    upper_condition = self.upper_endpoint > series\n                else:\n                    upper_condition = self.upper_endpoint >= series\n\n            return lower_condition & upper_condition\n\n        def __repr__(self) -> str:\n            return f\"Range {self.lower_endpoint} - {self.upper_endpoint}\"\n\n    # def __init__(self, java_sarg: org.apache.calcite.util.Sarg, literal_type: str):\n    #     self.ranges = [\n    #         SargPythonImplementation.Range(r, literal_type)\n    #         for r in java_sarg.rangeSet.asRanges()\n    #     ]\n\n    def __repr__(self) -> str:\n        return \",\".join(map(str, self.ranges))\n\n\nclass RexLiteralPlugin(BaseRexPlugin):\n    \"\"\"\n    A RexLiteral in an expression stands for a bare single value.\n    The task of this class is therefore just to extract this\n    value from the java instance and convert it\n    into the correct python type.\n    It is typically used when specifying a literal in a SQL expression,\n    e.g. 
in a filter.\n    \"\"\"\n\n    class_name = \"RexLiteral\"\n\n    def convert(\n        self,\n        rel: \"LogicalPlan\",\n        rex: \"Expression\",\n        dc: DataContainer,\n        context: \"dask_sql.Context\",\n    ) -> Any:\n        literal_type = str(rex.getType())\n\n        # Call the Rust function to get the actual value and convert the Rust\n        # type name back to a SQL type\n        if literal_type == \"Boolean\":\n            try:\n                literal_type = SqlTypeName.BOOLEAN\n                literal_value = rex.getBoolValue()\n            except TypeError:\n                literal_type = SqlTypeName.NULL\n                literal_value = None\n        elif literal_type == \"Float32\":\n            literal_type = SqlTypeName.FLOAT\n            literal_value = rex.getFloat32Value()\n        elif literal_type == \"Float64\":\n            literal_type = SqlTypeName.DOUBLE\n            literal_value = rex.getFloat64Value()\n        elif literal_type == \"Decimal128\":\n            literal_type = SqlTypeName.DECIMAL\n            value, _, scale = rex.getDecimal128Value()\n            literal_value = value / (10**scale)\n        elif literal_type == \"UInt8\":\n            literal_type = SqlTypeName.TINYINT\n            literal_value = rex.getUInt8Value()\n        elif literal_type == \"UInt16\":\n            literal_type = SqlTypeName.SMALLINT\n            literal_value = rex.getUInt16Value()\n        elif literal_type == \"UInt32\":\n            literal_type = SqlTypeName.INTEGER\n            literal_value = rex.getUInt32Value()\n        elif literal_type == \"UInt64\":\n            literal_type = SqlTypeName.BIGINT\n            literal_value = rex.getUInt64Value()\n        elif literal_type == \"Int8\":\n            literal_type = SqlTypeName.TINYINT\n            literal_value = rex.getInt8Value()\n        elif literal_type == \"Int16\":\n            literal_type = SqlTypeName.SMALLINT\n            literal_value = rex.getInt16Value()\n  
      elif literal_type == \"Int32\":\n            literal_type = SqlTypeName.INTEGER\n            literal_value = rex.getInt32Value()\n        elif literal_type == \"Int64\":\n            literal_type = SqlTypeName.BIGINT\n            literal_value = rex.getInt64Value()\n        elif literal_type == \"Utf8\":\n            literal_type = SqlTypeName.VARCHAR\n            literal_value = rex.getStringValue()\n        elif literal_type == \"Date32\":\n            literal_type = SqlTypeName.DATE\n            literal_value = np.datetime64(rex.getDate32Value(), \"D\")\n        elif literal_type == \"Date64\":\n            literal_type = SqlTypeName.DATE\n            literal_value = np.datetime64(rex.getDate64Value(), \"ms\")\n        elif literal_type == \"Time64\":\n            literal_value = np.datetime64(rex.getTime64Value(), \"ns\")\n            literal_type = SqlTypeName.TIME\n        elif literal_type == \"Null\":\n            literal_type = SqlTypeName.NULL\n            literal_value = None\n        elif literal_type == \"IntervalDayTime\":\n            literal_type = SqlTypeName.INTERVAL_DAY\n            literal_value = rex.getIntervalDayTimeValue()\n        elif literal_type == \"IntervalMonthDayNano\":\n            literal_type = SqlTypeName.INTERVAL_MONTH_DAY_NANOSECOND\n            literal_value = rex.getIntervalMonthDayNanoValue()\n        elif literal_type in {\n            \"TimestampSecond\",\n            \"TimestampMillisecond\",\n            \"TimestampMicrosecond\",\n            \"TimestampNanosecond\",\n        }:\n            unit_mapping = {\n                \"TimestampSecond\": \"s\",\n                \"TimestampMillisecond\": \"ms\",\n                \"TimestampMicrosecond\": \"us\",\n                \"TimestampNanosecond\": \"ns\",\n            }\n            numpy_unit = unit_mapping.get(literal_type)\n            literal_value, timezone = rex.getTimestampValue()\n            if timezone and timezone != \"UTC\":\n                raise 
ValueError(\"Non UTC timezones not supported\")\n            elif timezone is None:\n                literal_value = datetime.fromtimestamp(literal_value // 10**9)\n                literal_value = str(literal_value)\n            literal_type = SqlTypeName.TIMESTAMP\n            literal_value = np.datetime64(literal_value, numpy_unit)\n        else:\n            raise RuntimeError(\n                f\"Failed to map literal type {literal_type} to python type in literal.py\"\n            )\n\n        # if isinstance(literal_value, org.apache.calcite.util.Sarg):\n        #     return SargPythonImplementation(literal_value, literal_type)\n\n        python_value = sql_to_python_value(literal_type, literal_value)\n        logger.debug(\n            f\"literal.py python_value: {python_value} of Python type: {type(python_value)}\"\n        )\n\n        return python_value\n"
  },
  {
    "path": "dask_sql/physical/rex/core/subquery.py",
    "content": "from typing import TYPE_CHECKING\n\nimport dask.dataframe as dd\n\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.physical.rel import RelConverter\nfrom dask_sql.physical.rex.base import BaseRexPlugin\n\nif TYPE_CHECKING:\n    import dask_sql\n    from dask_sql._datafusion_lib import Expression, LogicalPlan\n\n\nclass RexScalarSubqueryPlugin(BaseRexPlugin):\n    \"\"\"\n    A RexScalarSubqueryPlugin handles an expression which references a subquery.\n    This plugin is thin on logic; however, in keeping with previous patterns,\n    we use the plugin approach instead of placing the logic inline.\n    \"\"\"\n\n    class_name = \"ScalarSubquery\"\n\n    def convert(\n        self,\n        rel: \"LogicalPlan\",\n        rex: \"Expression\",\n        dc: DataContainer,\n        context: \"dask_sql.Context\",\n    ) -> dd.DataFrame:\n\n        # Extract the LogicalPlan from the Expr instance\n        sub_rel = rex.getSubqueryLogicalPlan()\n\n        dc = RelConverter.convert(sub_rel, context=context)\n        return dc.df\n"
  },
  {
    "path": "dask_sql/physical/utils/__init__.py",
    "content": ""
  },
  {
    "path": "dask_sql/physical/utils/filter.py",
    "content": "from __future__ import annotations\n\nimport itertools\nimport logging\nimport operator\n\nimport dask.dataframe as dd\nimport numpy as np\nfrom dask.blockwise import Blockwise\nfrom dask.highlevelgraph import HighLevelGraph, MaterializedLayer\nfrom dask.layers import DataFrameIOLayer\nfrom dask.utils import M, apply, is_arraylike\n\nlogger = logging.getLogger(__name__)\n\n\ndef attempt_predicate_pushdown(\n    ddf: dd.DataFrame,\n    preserve_filters: bool = True,\n    extract_filters: bool = True,\n    add_filters: list | tuple | DNF | None = None,\n) -> dd.DataFrame:\n    \"\"\"Use graph information to update IO-level filters\n\n    The original `ddf` will be returned if/when the\n    predicate-pushdown optimization fails.\n\n    This is a special optimization that must be called\n    eagerly on a DataFrame collection when filters are\n    applied. The \"eager\" requirement for this optimization\n    is due to the fact that `npartitions` and `divisions`\n    may change when this optimization is applied (invalidating\n    npartition/divisions-specific logic in following Layers).\n\n    Parameters\n    ----------\n    ddf\n        Dask-DataFrame target for predicate pushdown.\n    preserve_filters\n        Whether to preserve pre-existing filters in the case that either\n        `add_filters` is specified, or `extract_filters` is `True` and\n        filters are successfully extracted from `ddf`. Default is `True`.\n    extract_filters\n        Whether to extract filters from the task graph of `ddf`. Default\n        is `True`.\n    add_filters\n        Custom filters to manually add to the IO layer of `ddf`.\n    \"\"\"\n\n    if not (extract_filters or add_filters):\n        # Not extracting filters from the graph or\n        # manually adding user-defined filters. 
Return\n        return ddf\n\n    # Check that we have a supported `ddf` object\n    if not isinstance(ddf, dd.DataFrame):\n        raise ValueError(\n            f\"Predicate pushdown optimization skipped. Type {type(ddf)} \"\n            f\"does not support predicate pushdown.\"\n        )\n    elif not isinstance(ddf.dask, HighLevelGraph):\n        logger.warning(\n            f\"Predicate pushdown optimization skipped. Graph must be \"\n            f\"a HighLevelGraph object (got {type(ddf.dask)}).\"\n        )\n        return ddf\n\n    # We were able to extract a DNF filter expression.\n    # Check that we have a single IO layer with `filters` support\n    io_layer = []\n    for k, v in ddf.dask.layers.items():\n        if isinstance(v, DataFrameIOLayer):\n            io_layer.append(k)\n            creation_info = (\n                (v.creation_info or {}) if hasattr(v, \"creation_info\") else {}\n            )\n            if \"filters\" not in creation_info.get(\"kwargs\", {}):\n                # No filters support\n                return ddf\n    if len(io_layer) != 1:\n        # Not a single IO layer\n        return ddf\n    io_layer = io_layer.pop()\n\n    # Get pre-existing filters\n    existing_filters = (\n        ddf.dask.layers[io_layer].creation_info.get(\"kwargs\", {}).get(\"filters\")\n    )\n\n    # Start by converting the HLG to a `RegenerableGraph`.\n    # Succeeding here means that all layers in the graph\n    # are regenerable.\n    try:\n        dsk = RegenerableGraph.from_hlg(ddf.dask)\n    except (ValueError, TypeError):\n        logger.warning(\n            \"Predicate pushdown optimization skipped. 
One or more \"\n            \"layers in the HighLevelGraph was not 'regenerable'.\"\n        )\n        return ddf\n\n    name = ddf._name\n    extracted_filters = DNF(None)\n    if extract_filters:\n        # Extract a DNF-formatted filter expression\n        try:\n            extracted_filters = dsk.layers[name]._dnf_filter_expression(dsk)\n        except (ValueError, TypeError):\n            # DNF dispatching failed for 1+ layers\n            logger.warning(\n                \"Predicate pushdown optimization skipped. One or more \"\n                \"layers has an unknown filter expression.\"\n            )\n\n    # Combine filters\n    filters = DNF(None)\n    if preserve_filters:\n        filters = filters.combine(existing_filters)\n    if extract_filters:\n        filters = filters.combine(extracted_filters)\n    if add_filters:\n        filters = filters.combine(add_filters)\n    if not filters:\n        # No filters encountered\n        return ddf\n    filters = filters.to_list_tuple()\n\n    # FIXME: pyarrow doesn't seem to like converting datetime64[D] to scalars\n    # so we must convert any we encounter to datetime64[ns]\n    filters = [\n        [\n            (\n                col,\n                op,\n                val.astype(\"datetime64[ns]\")\n                if isinstance(val, np.datetime64) and val.dtype == \"datetime64[D]\"\n                else val,\n            )\n            for col, op, val in sublist\n        ]\n        for sublist in filters\n    ]\n\n    # Regenerate collection with filtered IO layer\n    try:\n        _regen_cache = {}\n        return dsk.layers[name]._regenerate_collection(\n            dsk,\n            # TODO: shouldn't need to specify index=False after dask#9661 is merged\n            new_kwargs={io_layer: {\"filters\": filters, \"index\": False}},\n            _regen_cache=_regen_cache,\n        )\n    except ValueError as err:\n        # Most-likely failed to apply filters in read_parquet.\n        # We can 
just bail on predicate pushdown, but we also\n        # log a warning to encourage the user to file an issue.\n        logger.warning(\n            f\"Predicate pushdown failed to apply filters: {filters}. \"\n            f\"Please open a bug report at \"\n            f\"https://github.com/dask-contrib/dask-sql/issues/new/choose \"\n            f\"and include the following error message: {err}\"\n        )\n\n        return ddf\n\n\nclass DNF:\n    \"\"\"Manage filters in Disjunctive Normal Form (DNF)\"\"\"\n\n    class _Or(frozenset):\n        \"\"\"Frozen set of disjunctions\"\"\"\n\n        def to_list_tuple(self) -> list:\n            # DNF \"or\" is List[List[Tuple]]\n            def _maybe_list(val):\n                if isinstance(val, tuple) and val and isinstance(val[0], (tuple, list)):\n                    return list(val)\n                return [val]\n\n            return [\n                _maybe_list(val.to_list_tuple())\n                if hasattr(val, \"to_list_tuple\")\n                else _maybe_list(val)\n                for val in self\n            ]\n\n    class _And(frozenset):\n        \"\"\"Frozen set of conjunctions\"\"\"\n\n        def to_list_tuple(self) -> tuple:\n            # DNF \"and\" is returned as a tuple of predicates;\n            # `_Or.to_list_tuple` converts it to List[Tuple]\n            return tuple(\n                val.to_list_tuple() if hasattr(val, \"to_list_tuple\") else val\n                for val in self\n            )\n\n    _filters: _And | _Or | None  # Underlying filter expression\n\n    def __init__(self, filters: DNF | _And | _Or | list | tuple | None) -> None:\n        if isinstance(filters, DNF):\n            self._filters = filters._filters\n        else:\n            self._filters = self.normalize(filters)\n\n    def to_list_tuple(self) -> list:\n        return self._filters.to_list_tuple()\n\n    def __bool__(self) -> bool:\n        return bool(self._filters)\n\n    @classmethod\n    def normalize(cls, filters: _And | _Or | list | tuple | None):\n        \"\"\"Convert raw filters to the 
`_Or(_And)` DNF representation\"\"\"\n\n        def _valid_tuple(predicate: tuple):\n            col, op, val = predicate\n            if isinstance(col, tuple):\n                raise TypeError(\"filters must be List[Tuple] or List[List[Tuple]]\")\n            if op in (\"in\", \"not in\"):\n                return (col, op, tuple(val))\n            else:\n                return predicate\n\n        def _valid_list(conjunction: list):\n            valid = []\n            for predicate in conjunction:\n                if not isinstance(predicate, tuple):\n                    raise TypeError(f\"Predicate must be a tuple, got {predicate}\")\n                valid.append(_valid_tuple(predicate))\n            return valid\n\n        if not filters:\n            result = None\n        elif isinstance(filters, list):\n            conjunctions = filters if isinstance(filters[0], list) else [filters]\n            result = cls._Or(\n                [cls._And(_valid_list(conjunction)) for conjunction in conjunctions]\n            )\n        elif isinstance(filters, tuple):\n            result = cls._Or((cls._And((_valid_tuple(filters),)),))\n        elif isinstance(filters, cls._Or):\n            result = cls._Or(se for e in filters for se in cls.normalize(e))\n        elif isinstance(filters, cls._And):\n            total = []\n            for c in itertools.product(*[cls.normalize(e) for e in filters]):\n                total.append(cls._And(se for e in c for se in e))\n            result = cls._Or(total)\n        else:\n            raise TypeError(f\"{type(filters)} not a supported type for DNF\")\n        return result\n\n    def combine(self, other: DNF | _And | _Or | list | tuple | None) -> DNF:\n        \"\"\"Combine with another DNF object\"\"\"\n        if not isinstance(other, DNF):\n            other = DNF(other)\n        assert isinstance(other, DNF)\n        if self._filters is None:\n            result = other._filters\n        elif other._filters is None:\n     
       result = self._filters\n        else:\n            result = self._And([self._filters, other._filters])\n        return DNF(result)\n\n\n# Define all supported comparison functions\n# (and their mapping to a string expression)\n_comparison_symbols = {\n    operator.eq: \"==\",\n    operator.ne: \"!=\",\n    operator.lt: \"<\",\n    operator.le: \"<=\",\n    operator.gt: \">\",\n    operator.ge: \">=\",\n    np.greater: \">\",\n    np.greater_equal: \">=\",\n    np.less: \"<\",\n    np.less_equal: \"<=\",\n    np.equal: \"==\",\n    np.not_equal: \"!=\",\n}\n\n# Define all regenerable \"pass-through\" ops\n# that do not affect filters.\n_pass_through_ops = {M.fillna, M.astype}\n\n# Define set of all \"regenerable\" operations.\n# Predicate pushdown is supported for graphs\n# comprised of `Blockwise` layers based on these\n# operations\n_regenerable_ops = (\n    set(_comparison_symbols.keys())\n    | {\n        operator.and_,\n        operator.or_,\n        operator.getitem,\n        operator.inv,\n        M.isin,\n        M.isna,\n    }\n    | _pass_through_ops\n)\n\n# Specify functions that must be generated with\n# a different API at the dataframe-collection level\n_special_op_mappings = {\n    M.fillna: dd.DataFrame.fillna,\n    M.isin: dd.DataFrame.isin,\n    M.isna: dd.DataFrame.isna,\n    M.astype: dd.DataFrame.astype,\n}\n\n# Convert _pass_through_ops to respect \"special\" mappings\n_pass_through_ops = {_special_op_mappings.get(op, op) for op in _pass_through_ops}\n\n\ndef _preprocess_layers(input_layers):\n    # NOTE: This is a Layer-specific work-around to deal with\n    # the fact that `dd.DataFrame.isin(values)` will add a distinct\n    # `MaterializedLayer` for the `values` argument.\n    # See: https://github.com/dask-contrib/dask-sql/issues/607\n    skip = set()\n    layers = input_layers.copy()\n    for key, layer in layers.items():\n        if key.startswith(\"isin-\") and isinstance(layer, Blockwise):\n            indices = 
list(layer.indices)\n            for i, (k, ind) in enumerate(layer.indices):\n                if (\n                    ind is None\n                    and isinstance(layers.get(k), MaterializedLayer)\n                    and isinstance(layers[k].get(k), (np.ndarray, tuple))\n                ):\n                    # Replace `indices[i]` with a literal value and\n                    # make sure we skip the `MaterializedLayer` that\n                    # we are now fusing into the `isin`\n                    value = layers[k][k]\n                    value = value[0](*value[1:]) if callable(value[0]) else value\n                    indices[i] = (value, None)\n                    skip.add(k)\n            layer.indices = tuple(indices)\n    return {k: v for k, v in layers.items() if k not in skip}\n\n\nclass RegenerableLayer:\n    \"\"\"Regenerable Layer\n\n    Wraps ``dask.highlevelgraph.Blockwise`` to ensure that a\n    ``creation_info`` attribute  is defined. This class\n    also defines the necessary methods for recursive\n    layer regeneration and filter-expression generation.\n    \"\"\"\n\n    def __init__(self, layer, creation_info):\n        self.layer = layer  # Original Blockwise layer reference\n        self.creation_info = creation_info\n\n    def _regenerate_collection(\n        self,\n        dsk,\n        new_kwargs: dict = None,\n        _regen_cache: dict = None,\n    ):\n        \"\"\"Regenerate a Dask collection for this layer using the\n        provided inputs and key-word arguments\n        \"\"\"\n\n        # Return regenerated layer if the work was\n        # already done\n        if _regen_cache is None:\n            _regen_cache = {}\n        if self.layer.output in _regen_cache:\n            return _regen_cache[self.layer.output]\n\n        # Recursively generate necessary inputs to\n        # this layer to generate the collection\n        inputs = []\n        for key, ind in self.layer.indices:\n            if ind is None:\n               
 if isinstance(key, (str, tuple)) and key in dsk.layers:\n                    continue\n                inputs.append(key)\n            elif key in self.layer.io_deps:\n                continue\n            else:\n                inputs.append(\n                    dsk.layers[key]._regenerate_collection(\n                        dsk,\n                        new_kwargs=new_kwargs,\n                        _regen_cache=_regen_cache,\n                    )\n                )\n\n        # Extract the callable func and key-word args.\n        # Then return a regenerated collection\n        func = self.creation_info.get(\"func\", None)\n        if func is None:\n            raise ValueError(\n                \"`_regenerate_collection` failed. \"\n                \"Not all HLG layers are regenerable.\"\n            )\n        regen_args = self.creation_info.get(\"args\", [])\n        regen_kwargs = self.creation_info.get(\"kwargs\", {}).copy()\n        regen_kwargs = {k: v for k, v in self.creation_info.get(\"kwargs\", {}).items()}\n        regen_kwargs.update((new_kwargs or {}).get(self.layer.output, {}))\n\n        result = func(*inputs, *regen_args, **regen_kwargs)\n        _regen_cache[self.layer.output] = result\n        return result\n\n    def _dnf_filter_expression(self, dsk):\n        \"\"\"Return a DNF-formatted filter expression for the\n        graph terminating at this layer\n        \"\"\"\n        op = self.creation_info[\"func\"]\n        if op in _comparison_symbols.keys():\n            func = _blockwise_comparison_dnf\n        elif op in (operator.and_, operator.or_):\n            func = _blockwise_logical_dnf\n        elif op == operator.getitem:\n            func = _blockwise_getitem_dnf\n        elif op == dd.DataFrame.isin:\n            func = _blockwise_isin_dnf\n        elif op == dd.DataFrame.isna:\n            func = _blockwise_isna_dnf\n        elif op == operator.inv:\n            func = _blockwise_inv_dnf\n        elif op in 
_pass_through_ops:\n            func = _blockwise_pass_through_dnf\n        else:\n            raise ValueError(f\"No DNF expression for {op}\")\n\n        return func(op, self.layer.indices, dsk)\n\n\nclass RegenerableGraph:\n    \"\"\"Regenerable Graph\n\n    This class is similar to ``dask.highlevelgraph.HighLevelGraph``.\n    However, all layers in a ``RegenerableGraph`` graph must be\n    ``RegenerableLayer`` objects (which wrap ``Blockwise`` layers).\n    \"\"\"\n\n    def __init__(self, layers: dict):\n        self.layers = layers\n\n    @classmethod\n    def from_hlg(cls, hlg: HighLevelGraph):\n        \"\"\"Construct a ``RegenerableGraph`` from a ``HighLevelGraph``\"\"\"\n\n        if not isinstance(hlg, HighLevelGraph):\n            raise TypeError(f\"Expected HighLevelGraph, got {type(hlg)}\")\n\n        _layers = {}\n        for key, layer in _preprocess_layers(hlg.layers).items():\n            regenerable_layer = None\n            if isinstance(layer, DataFrameIOLayer):\n                regenerable_layer = RegenerableLayer(layer, layer.creation_info or {})\n            elif isinstance(layer, Blockwise):\n                tasks = list(layer.dsk.values())\n                if len(tasks) == 1 and tasks[0]:\n                    kwargs = {}\n                    if tasks[0][0] == apply:\n                        op = tasks[0][1]\n                        options = tasks[0][3]\n                        if isinstance(options, dict):\n                            kwargs = options\n                        elif (\n                            isinstance(options, tuple)\n                            and options\n                            and callable(options[0])\n                        ):\n                            kwargs = options[0](*options[1:])\n                    else:\n                        op = tasks[0][0]\n                    if op in _regenerable_ops:\n                        regenerable_layer = RegenerableLayer(\n                            layer,\n      
                      {\n                                \"func\": _special_op_mappings.get(op, op),\n                                \"kwargs\": kwargs,\n                            },\n                        )\n\n            if regenerable_layer is None:\n                raise ValueError(f\"Graph contains non-regenerable layer: {layer}\")\n\n            _layers[key] = regenerable_layer\n\n        return RegenerableGraph(_layers)\n\n\ndef _get_blockwise_input(input_index, indices: list, dsk: RegenerableGraph):\n    # Simple utility to get the required input expressions\n    # for a Blockwise layer (using indices)\n    key = indices[input_index][0]\n    if indices[input_index][1] is None:\n        return key\n    return dsk.layers[key]._dnf_filter_expression(dsk)\n\n\ndef _inv(symbol: str):\n    return {\n        \">\": \"<\",\n        \"<\": \">\",\n        \">=\": \"<=\",\n        \"<=\": \">=\",\n        \"in\": \"not in\",\n        \"not in\": \"in\",\n        \"is\": \"is not\",\n        \"is not\": \"is\",\n    }.get(symbol, symbol)\n\n\ndef _blockwise_comparison_dnf(op, indices: list, dsk: RegenerableGraph) -> DNF:\n    # Return DNF expression pattern for a simple comparison\n    left = _get_blockwise_input(0, indices, dsk)\n    right = _get_blockwise_input(1, indices, dsk)\n\n    if is_arraylike(left) and hasattr(left, \"item\") and left.size == 1:\n        left = left.item()\n        # Need inverse comparison in read_parquet\n        return DNF((right, _inv(_comparison_symbols[op]), left))\n    if is_arraylike(right) and hasattr(right, \"item\") and right.size == 1:\n        right = right.item()\n    return DNF((left, _comparison_symbols[op], right))\n\n\ndef _blockwise_logical_dnf(op, indices: list, dsk: RegenerableGraph) -> DNF:\n    # Return DNF expression pattern for logical \"and\" or \"or\"\n    left = _get_blockwise_input(0, indices, dsk)\n    right = _get_blockwise_input(1, indices, dsk)\n\n    filters = []\n    for val in [left, right]:\n        
if not isinstance(val, (tuple, DNF)):\n            raise TypeError(f\"Invalid logical operand: {val}\")\n        filters.append(DNF(val)._filters)\n\n    if op == operator.or_:\n        return DNF(DNF._Or(filters))\n    elif op == operator.and_:\n        return DNF(DNF._And(filters))\n    else:\n        raise ValueError\n\n\ndef _blockwise_getitem_dnf(op, indices: list, dsk: RegenerableGraph):\n    # Return dnf of key (selected by getitem)\n    key = _get_blockwise_input(1, indices, dsk)\n    return key\n\n\ndef _blockwise_pass_through_dnf(op, indices: list, dsk: RegenerableGraph):\n    # Return dnf of input collection\n    return _get_blockwise_input(0, indices, dsk)\n\n\ndef _blockwise_isin_dnf(op, indices: list, dsk: RegenerableGraph) -> DNF:\n    # Return DNF expression pattern for a simple \"in\" comparison\n    left = _get_blockwise_input(0, indices, dsk)\n    right = _get_blockwise_input(1, indices, dsk)\n    return DNF((left, \"in\", tuple(right)))\n\n\ndef _blockwise_isna_dnf(op, indices: list, dsk: RegenerableGraph) -> DNF:\n    # Return DNF expression pattern for `isna`\n    left = _get_blockwise_input(0, indices, dsk)\n    return DNF((left, \"is\", None))\n\n\ndef _blockwise_inv_dnf(op, indices: list, dsk: RegenerableGraph) -> DNF:\n    # Return DNF expression pattern for the inverse of a comparison\n    expr = _get_blockwise_input(0, indices, dsk).to_list_tuple()\n    new_expr = []\n    count = 0\n    for conjunction in expr:\n        new_conjunction = []\n        for col, op, val in conjunction:\n            count += 1\n            new_conjunction.append((col, _inv(op), val))\n        new_expr.append(DNF._And(new_conjunction))\n    if count > 1:\n        # Haven't taken the time to think through\n        # general inversion yet.\n        raise ValueError(\"inv(DNF) case not implemented.\")\n    return DNF(DNF._Or(new_expr))\n"
  },
  {
    "path": "dask_sql/physical/utils/groupby.py",
    "content": "import dask.dataframe as dd\n\nfrom dask_sql.utils import new_temporary_column\n\n\ndef get_groupby_with_nulls_cols(\n    df: dd.DataFrame, group_columns: list[str], additional_column_name: str = None\n):\n    \"\"\"\n    SQL and dask treat null values in group columns differently:\n    SQL puts them at the front, while dask just ignores them.\n    Therefore we use the same trick as fugue does:\n    we group by both the null indicator and the real column value.\n    \"\"\"\n    if additional_column_name is None:\n        additional_column_name = new_temporary_column(df)\n\n    group_columns_and_nulls = []\n    for group_column in group_columns:\n        is_null_column = group_column.isnull()\n        non_nan_group_column = group_column.fillna(0)\n\n        # split_out doesn't work if both columns have the same name\n        is_null_column.name = f\"{is_null_column.name}_{new_temporary_column(df)}\"\n\n        group_columns_and_nulls += [is_null_column, non_nan_group_column]\n\n    if not group_columns_and_nulls:\n        # This can happen in statements like\n        # SELECT SUM(x) FROM data\n        # without any groupby statement\n        group_columns_and_nulls = [additional_column_name]\n\n    return group_columns_and_nulls\n"
  },
  {
    "path": "dask_sql/physical/utils/ml_classes.py",
    "content": "def get_cpu_classes():\n    try:\n        from sklearn.utils import all_estimators\n\n        cpu_classes = {\n            k: v.__module__ + \".\" + v.__qualname__ for k, v in all_estimators()\n        }\n    except ImportError:\n        cpu_classes = {}\n\n    cpu_classes = add_boosting_classes(cpu_classes)\n\n    return cpu_classes\n\n\ndef get_gpu_classes():\n    gpu_classes = {\n        # cuml.dask\n        \"DBSCAN\": \"cuml.dask.cluster.dbscan.DBSCAN\",\n        \"KMeans\": \"cuml.dask.cluster.kmeans.KMeans\",\n        \"PCA\": \"cuml.dask.decomposition.pca.PCA\",\n        \"TruncatedSVD\": \"cuml.dask.decomposition.tsvd.TruncatedSVD\",\n        \"RandomForestClassifier\": \"cuml.dask.ensemble.randomforestclassifier.RandomForestClassifier\",\n        \"RandomForestRegressor\": \"cuml.dask.ensemble.randomforestregressor.RandomForestRegressor\",\n        \"LogisticRegression\": \"cuml.dask.extended.linear_model.logistic_regression.LogisticRegression\",\n        \"TfidfTransformer\": \"cuml.dask.feature_extraction.text.tfidf_transformer.TfidfTransformer\",\n        \"LinearRegression\": \"cuml.dask.linear_model.linear_regression.LinearRegression\",\n        \"Ridge\": \"cuml.dask.linear_model.ridge.Ridge\",\n        \"Lasso\": \"cuml.dask.linear_model.lasso.Lasso\",\n        \"ElasticNet\": \"cuml.dask.linear_model.elastic_net.ElasticNet\",\n        \"UMAP\": \"cuml.dask.manifold.umap.UMAP\",\n        \"MultinomialNB\": \"cuml.dask.naive_bayes.naive_bayes.MultinomialNB\",\n        \"NearestNeighbors\": \"cuml.dask.neighbors.nearest_neighbors.NearestNeighbors\",\n        \"KNeighborsClassifier\": \"cuml.dask.neighbors.kneighbors_classifier.KNeighborsClassifier\",\n        \"KNeighborsRegressor\": \"cuml.dask.neighbors.kneighbors_regressor.KNeighborsRegressor\",\n        \"LabelBinarizer\": \"cuml.dask.preprocessing.label.LabelBinarizer\",\n        \"OneHotEncoder\": \"cuml.dask.preprocessing.encoders.OneHotEncoder\",\n        \"LabelEncoder\": 
\"cuml.dask.preprocessing.LabelEncoder.LabelEncoder\",\n        \"CD\": \"cuml.dask.solvers.cd.CD\",\n        # cuml\n        \"Base\": \"cuml.internals.base.Base\",\n        \"Handle\": \"cuml.common.handle.Handle\",\n        \"AgglomerativeClustering\": \"cuml.cluster.agglomerative.AgglomerativeClustering\",\n        \"HDBSCAN\": \"cuml.cluster.hdbscan.HDBSCAN\",\n        \"IncrementalPCA\": \"cuml.decomposition.incremental_pca.IncrementalPCA\",\n        \"ForestInference\": \"cuml.fil.fil.ForestInference\",\n        \"KernelRidge\": \"cuml.kernel_ridge.kernel_ridge.KernelRidge\",\n        \"MBSGDClassifier\": \"cuml.linear_model.mbsgd_classifier.MBSGDClassifier\",\n        \"MBSGDRegressor\": \"cuml.linear_model.mbsgd_regressor.MBSGDRegressor\",\n        \"TSNE\": \"cuml.manifold.t_sne.TSNE\",\n        \"KernelDensity\": \"cuml.neighbors.kernel_density.KernelDensity\",\n        \"GaussianRandomProjection\": \"cuml.random_projection.random_projection.GaussianRandomProjection\",\n        \"SparseRandomProjection\": \"cuml.random_projection.random_projection.SparseRandomProjection\",\n        \"SGD\": \"cuml.solvers.sgd.SGD\",\n        \"QN\": \"cuml.solvers.qn.QN\",\n        \"SVC\": \"cuml.svm.SVC\",\n        \"SVR\": \"cuml.svm.SVR\",\n        \"LinearSVC\": \"cuml.svm.LinearSVC\",\n        \"LinearSVR\": \"cuml.svm.LinearSVR\",\n        \"ARIMA\": \"cuml.tsa.arima.ARIMA\",\n        \"AutoARIMA\": \"cuml.tsa.auto_arima.AutoARIMA\",\n        \"ExponentialSmoothing\": \"cuml.tsa.holtwinters.ExponentialSmoothing\",\n        # sklearn\n        \"Binarizer\": \"cuml.preprocessing.Binarizer\",\n        \"KernelCenterer\": \"cuml.preprocessing.KernelCenterer\",\n        \"MinMaxScaler\": \"cuml.preprocessing.MinMaxScaler\",\n        \"MaxAbsScaler\": \"cuml.preprocessing.MaxAbsScaler\",\n        \"Normalizer\": \"cuml.preprocessing.Normalizer\",\n        \"PolynomialFeatures\": \"cuml.preprocessing.PolynomialFeatures\",\n        \"PowerTransformer\": 
\"cuml.preprocessing.PowerTransformer\",\n        \"QuantileTransformer\": \"cuml.preprocessing.QuantileTransformer\",\n        \"RobustScaler\": \"cuml.preprocessing.RobustScaler\",\n        \"StandardScaler\": \"cuml.preprocessing.StandardScaler\",\n        \"SimpleImputer\": \"cuml.preprocessing.SimpleImputer\",\n        \"MissingIndicator\": \"cuml.preprocessing.MissingIndicator\",\n        \"KBinsDiscretizer\": \"cuml.preprocessing.KBinsDiscretizer\",\n        \"FunctionTransformer\": \"cuml.preprocessing.FunctionTransformer\",\n        \"ColumnTransformer\": \"cuml.compose.ColumnTransformer\",\n        \"GridSearchCV\": \"sklearn.model_selection.GridSearchCV\",\n        \"Pipeline\": \"sklearn.pipeline.Pipeline\",\n        # Other\n        \"UniversalBase\": \"cuml.internals.base.UniversalBase\",\n        \"Lars\": \"cuml.experimental.linear_model.lars.Lars\",\n        \"TfidfVectorizer\": \"cuml.feature_extraction._tfidf_vectorizer.TfidfVectorizer\",\n        \"CountVectorizer\": \"cuml.feature_extraction._vectorizers.CountVectorizer\",\n        \"HashingVectorizer\": \"cuml.feature_extraction._vectorizers.HashingVectorizer\",\n        \"StratifiedKFold\": \"cuml.model_selection._split.StratifiedKFold\",\n        \"OneVsOneClassifier\": \"cuml.multiclass.multiclass.OneVsOneClassifier\",\n        \"OneVsRestClassifier\": \"cuml.multiclass.multiclass.OneVsRestClassifier\",\n        \"MulticlassClassifier\": \"cuml.multiclass.multiclass.MulticlassClassifier\",\n        \"BernoulliNB\": \"cuml.naive_bayes.naive_bayes.BernoulliNB\",\n        \"GaussianNB\": \"cuml.naive_bayes.naive_bayes.GaussianNB\",\n        \"ComplementNB\": \"cuml.naive_bayes.naive_bayes.ComplementNB\",\n        \"CategoricalNB\": \"cuml.naive_bayes.naive_bayes.CategoricalNB\",\n        \"TargetEncoder\": \"cuml.preprocessing.TargetEncoder\",\n        \"PorterStemmer\": \"cuml.preprocessing.text.stem.porter_stemmer.PorterStemmer\",\n    }\n\n    gpu_classes = 
add_boosting_classes(gpu_classes)\n\n    return gpu_classes\n\n\ndef add_boosting_classes(my_classes):\n    my_classes[\"LGBMModel\"] = \"lightgbm.LGBMModel\"\n    my_classes[\"LGBMClassifier\"] = \"lightgbm.LGBMClassifier\"\n    my_classes[\"LGBMRegressor\"] = \"lightgbm.LGBMRegressor\"\n    my_classes[\"LGBMRanker\"] = \"lightgbm.LGBMRanker\"\n    my_classes[\"XGBRegressor\"] = \"xgboost.XGBRegressor\"\n    my_classes[\"XGBClassifier\"] = \"xgboost.XGBClassifier\"\n    my_classes[\"XGBRanker\"] = \"xgboost.XGBRanker\"\n    my_classes[\"XGBRFRegressor\"] = \"xgboost.XGBRFRegressor\"\n    my_classes[\"XGBRFClassifier\"] = \"xgboost.XGBRFClassifier\"\n    my_classes[\"DaskXGBClassifier\"] = \"xgboost.dask.DaskXGBClassifier\"\n    my_classes[\"DaskXGBRegressor\"] = \"xgboost.dask.DaskXGBRegressor\"\n    my_classes[\"DaskXGBRanker\"] = \"xgboost.dask.DaskXGBRanker\"\n    my_classes[\"DaskXGBRFRegressor\"] = \"xgboost.dask.DaskXGBRFRegressor\"\n    my_classes[\"DaskXGBRFClassifier\"] = \"xgboost.dask.DaskXGBRFClassifier\"\n\n    return my_classes\n"
  },
  {
    "path": "dask_sql/physical/utils/sort.py",
    "content": "import dask.dataframe as dd\nimport pandas as pd\nfrom dask import config as dask_config\nfrom dask.utils import M\n\nfrom dask_sql.utils import is_cudf_type\n\n\ndef apply_sort(\n    df: dd.DataFrame,\n    sort_columns: list[str],\n    sort_ascending: list[bool],\n    sort_null_first: list[bool],\n    sort_num_rows: int = None,\n) -> dd.DataFrame:\n    # when sort_values doesn't support lists of ascending / null\n    # position booleans, we can still do the sort provided that\n    # the list(s) are homogeneous:\n    single_ascending = len(set(sort_ascending)) == 1\n    single_null_first = len(set(sort_null_first)) == 1\n\n    if is_topk_optimizable(\n        df=df,\n        sort_columns=sort_columns,\n        single_ascending=single_ascending,\n        sort_null_first=sort_null_first,\n        sort_num_rows=sort_num_rows,\n    ):\n        return topk_sort(\n            df=df,\n            sort_columns=sort_columns,\n            sort_ascending=sort_ascending,\n            sort_num_rows=sort_num_rows,\n        )\n\n    else:\n        # Pre persist before sort to avoid duplicate compute\n        df = df.persist()\n\n    # pandas / cudf don't support lists of null positions\n    if df.npartitions == 1 and single_null_first:\n        return df.map_partitions(\n            M.sort_values,\n            by=sort_columns,\n            ascending=sort_ascending,\n            na_position=\"first\" if sort_null_first[0] else \"last\",\n        ).persist()\n\n    # dask / dask-cudf don't support lists of ascending / null positions\n    if len(sort_columns) == 1 or (\n        is_cudf_type(df) and single_ascending and single_null_first\n    ):\n        try:\n            return df.sort_values(\n                by=sort_columns,\n                ascending=sort_ascending[0],\n                na_position=\"first\" if sort_null_first[0] else \"last\",\n                # ignore_index=True,\n            ).persist()\n        except ValueError:\n            pass\n\n    # if 
standard `sort_values` can't handle ascending / null position params,\n    # we extend it using our custom sort function\n    return df.sort_values(\n        by=sort_columns[0],\n        ascending=sort_ascending[0],\n        na_position=\"first\" if sort_null_first[0] else \"last\",\n        sort_function=(sort_partition_func),\n        sort_function_kwargs={\n            \"sort_columns\": sort_columns,\n            \"sort_ascending\": sort_ascending,\n            \"sort_null_first\": sort_null_first,\n        },\n    ).persist()\n\n\ndef topk_sort(\n    df: dd.DataFrame,\n    sort_columns: list[str],\n    sort_ascending: list[bool],\n    sort_num_rows: int = None,\n):\n    if sort_ascending[0]:\n        return df.nsmallest(n=sort_num_rows, columns=sort_columns)\n    else:\n        return df.nlargest(n=sort_num_rows, columns=sort_columns)\n\n\ndef sort_partition_func(\n    partition: pd.DataFrame,\n    sort_columns: list[str],\n    sort_ascending: list[bool],\n    sort_null_first: list[bool],\n    **kwargs,\n):\n    if partition.empty:\n        return partition\n\n    # Trick: https://github.com/pandas-dev/pandas/issues/17111\n    # to make sorting faster\n    # With that, we can also allow for different NaN-orders by column\n    # For this, we start with the last sort column\n    # and use mergesort when we move to the front\n    for col, asc, null_first in reversed(\n        list(zip(sort_columns, sort_ascending, sort_null_first))\n    ):\n        if null_first:\n            na_position = \"first\"\n        else:\n            na_position = \"last\"\n\n        partition = partition.sort_values(\n            by=[col], ascending=asc, na_position=na_position, kind=\"mergesort\"\n        )\n\n    return partition\n\n\ndef is_topk_optimizable(\n    df: dd.DataFrame,\n    sort_columns: list[str],\n    single_ascending: bool,\n    sort_null_first: list[bool],\n    sort_num_rows: int = None,\n):\n    if (\n        sort_num_rows is None\n        or not single_ascending\n   
     or any(sort_null_first)\n        # pandas/cudf don't support nsmallest/nlargest with object dtypes\n        or any(df[sort_columns].dtypes == \"object\")\n        or (\n            sort_num_rows * len(df.columns)\n            > dask_config.get(\"sql.sort.topk-nelem-limit\")\n        )\n    ):\n        return False\n\n    return True\n"
  },
  {
    "path": "dask_sql/physical/utils/statistics.py",
    "content": "from __future__ import annotations\n\nimport itertools\nimport logging\nfrom collections import defaultdict\nfrom functools import lru_cache\n\nimport dask\nimport dask.dataframe as dd\nimport pyarrow.parquet as pq\nfrom dask.dataframe.io.parquet.arrow import ArrowDatasetEngine\nfrom dask.dataframe.io.parquet.core import ParquetFunctionWrapper\nfrom dask.dataframe.io.utils import _is_local_fs\nfrom dask.delayed import delayed\nfrom dask.layers import DataFrameIOLayer\nfrom dask.utils_test import hlg_layer\n\nlogger = logging.getLogger(__name__)\n\n\ndef parquet_statistics(\n    ddf: dd.DataFrame,\n    columns: list | None = None,\n    parallel: int | False | None = None,\n    **compute_kwargs,\n) -> list[dict] | None:\n    \"\"\"Extract Parquet statistics from a Dask DataFrame collection\n\n    WARNING: This API is experimental\n\n    Parameters\n    ----------\n    ddf\n        Dask-DataFrame object to extract Parquet statistics from.\n    columns\n        List of columns to collect min/max statistics for. If ``None``\n        (the default), only 'num-rows' statistics will be collected.\n    parallel\n        The number of distinct files to collect statistics for\n        within a distinct ``dask.delayed`` task. If ``False``, all\n        statistics will be parsed on the client process. If ``None``,\n        the value will be set to 16 for remote filesystem (e.g s3)\n        and ``False`` otherwise. Default is ``None``.\n    **compute_kwargs\n        Key-word arguments to pass through to ``dask.compute`` when\n        ``parallel`` is not ``False``.\n\n    Returns\n    -------\n    statistics\n        List of Parquet statistics. Each list element corresponds\n        to a distinct partition in ``ddf``. 
Each element of\n        ``statistics`` will correspond to a dictionary with\n        'num-rows' and 'columns' keys::\n\n            ``{'num-rows': 1024, 'columns': [...]}``\n\n        If column statistics are available, each element of the\n        list stored under the \"columns\" key will correspond to\n        a dictionary with \"name\", \"min\", and \"max\" keys::\n\n            ``{'name': 'col0', 'min': 0, 'max': 100}``\n    \"\"\"\n\n    # Check that we have a supported `ddf` object\n    if not isinstance(ddf, dd.DataFrame):\n        raise ValueError(f\"Expected Dask DataFrame, got {type(ddf)}.\")\n\n    # Be strict about columns argument\n    if columns:\n        if not isinstance(columns, list):\n            raise ValueError(f\"Expected columns to be a list, got {type(columns)}.\")\n        elif not set(columns).issubset(set(ddf.columns)):\n            raise ValueError(f\"columns={columns} must be a subset of {ddf.columns}\")\n\n    # Extract \"read-parquet\" layer from ddf\n    try:\n        layer = hlg_layer(ddf.dask, \"read-parquet\")\n    except KeyError:\n        layer = None\n\n    # Make sure we are dealing with a\n    # ParquetFunctionWrapper-based DataFrameIOLayer\n    if not isinstance(layer, DataFrameIOLayer) or not isinstance(\n        layer.io_func, ParquetFunctionWrapper\n    ):\n        logger.debug(\n            f\"Could not extract Parquet statistics from {ddf}.\"\n            f\"\\nAttempted IO layer: {layer}\"\n        )\n        return None\n\n    # Collect statistics using layer information\n    parts = layer.inputs\n    fs = layer.io_func.fs\n    engine = layer.io_func.engine\n    if not issubclass(engine, ArrowDatasetEngine):\n        logger.debug(\n            f\"Could not extract Parquet statistics from {ddf}.\"\n            f\"\\nUnsupported parquet engine: {engine}\"\n        )\n        return None\n\n    # Set default\n    if parallel is None:\n        parallel = False if _is_local_fs(fs) else 16\n    parallel = 
int(parallel)\n\n    if parallel:\n        # Group parts corresponding to the same file.\n        # A single task should always parse statistics\n        # for all these parts at once (since they will\n        # all be in the same footer)\n        groups = defaultdict(list)\n        for part in parts:\n            for p in [part] if isinstance(part, dict) else part:\n                path = p.get(\"piece\")[0]\n                groups[path].append(p)\n        group_keys = list(groups.keys())\n\n        # Compute and return flattened result\n        func = delayed(_read_partition_stats_group)\n        result = dask.compute(\n            [\n                func(\n                    list(\n                        itertools.chain(\n                            *[groups[k] for k in group_keys[i : i + parallel]]\n                        )\n                    ),\n                    fs,\n                    engine,\n                    columns=columns,\n                )\n                for i in range(0, len(group_keys), parallel)\n            ],\n            **(compute_kwargs or {}),\n        )[0]\n        return list(itertools.chain(*result))\n    else:\n        # Serial computation on client\n        return _read_partition_stats_group(parts, fs, engine, columns=columns)\n\n\ndef _read_partition_stats_group(parts, fs, engine, columns=None):\n    def _read_partition_stats(part, fs, columns=None):\n        # Helper function to read Parquet-metadata\n        # statistics for a single partition\n\n        if not isinstance(part, list):\n            part = [part]\n\n        column_stats = {}\n        num_rows = 0\n        columns = columns or []\n        for p in part:\n            piece = p[\"piece\"]\n            path = piece[0]\n            row_groups = None if piece[1] == [None] else piece[1]\n            md = _get_md(path, fs)\n            if row_groups is None:\n                row_groups = list(range(md.num_row_groups))\n            for rg in row_groups:\n             
   row_group = md.row_group(rg)\n                num_rows += row_group.num_rows\n                for i in range(row_group.num_columns):\n                    col = row_group.column(i)\n                    name = col.path_in_schema\n                    if name in columns:\n                        if col.statistics and col.statistics.has_min_max:\n                            if name in column_stats:\n                                column_stats[name][\"min\"] = min(\n                                    column_stats[name][\"min\"], col.statistics.min\n                                )\n                                column_stats[name][\"max\"] = max(\n                                    column_stats[name][\"max\"], col.statistics.max\n                                )\n                            else:\n                                column_stats[name] = {\n                                    \"min\": col.statistics.min,\n                                    \"max\": col.statistics.max,\n                                }\n\n        # Convert dict-of-dict to list-of-dict to be consistent\n        # with current `dd.read_parquet` convention (for now)\n        column_stats_list = [\n            {\n                \"name\": name,\n                \"min\": column_stats[name][\"min\"],\n                \"max\": column_stats[name][\"max\"],\n            }\n            for name in column_stats.keys()\n        ]\n        return {\"num-rows\": num_rows, \"columns\": column_stats_list}\n\n    @lru_cache(maxsize=1)\n    def _get_md(path, fs):\n        # Caching utility to avoid parsing the same footer\n        # metadata multiple times\n        with fs.open(path, default_cache=\"none\") as f:\n            return pq.ParquetFile(f).metadata\n\n    # Helper function used by _extract_statistics\n    return [_read_partition_stats(part, fs, columns=columns) for part in parts]\n"
  },
  {
    "path": "dask_sql/server/__init__.py",
    "content": ""
  },
  {
    "path": "dask_sql/server/app.py",
    "content": "import asyncio\nimport logging\nfrom argparse import ArgumentParser\nfrom uuid import uuid4\n\nimport dask.distributed\nimport uvicorn\nfrom fastapi import FastAPI, HTTPException, Request\nfrom uvicorn import Config, Server\n\nfrom dask_sql.context import Context\nfrom dask_sql.server.presto_jdbc import create_meta_data\nfrom dask_sql.server.responses import DataResults, ErrorResults, QueryResults\n\napp = FastAPI()\nlogger = logging.getLogger(__name__)\n\n\n@app.get(\"/v1/empty\")\nasync def empty(request: Request):\n    \"\"\"\n    Helper endpoint returning an empty\n    result.\n    \"\"\"\n    return QueryResults(request=request)\n\n\n@app.delete(\"/v1/cancel/{uuid}\")\nasync def cancel(uuid: str, request: Request):\n    \"\"\"\n    Cancel an already running computation\n    \"\"\"\n    logger.debug(f\"Canceling the request with uuid {uuid}\")\n    try:\n        future = request.app.future_list[uuid]\n    except KeyError:\n        raise HTTPException(status_code=404, detail=\"uuid not found\")\n    future.cancel()\n    del request.app.future_list[uuid]\n\n    return {\"status\": \"ok\"}\n\n\n@app.get(\"/v1/status/{uuid}\")\nasync def status(uuid: str, request: Request):\n    \"\"\"\n    Return the status (or the result) of an already running calculation\n    \"\"\"\n    logger.debug(f\"Accessing the request with uuid {uuid}\")\n    try:\n        future = request.app.future_list[uuid]\n    except KeyError:\n        raise HTTPException(status_code=404, detail=\"uuid not found\")\n\n    if future.done():\n        logger.debug(f\"{uuid} is already finished, returning data\")\n        df = future.result()\n\n        del request.app.future_list[uuid]\n\n        return DataResults(df, request=request)\n\n    logger.debug(f\"{uuid} is not already finished\")\n\n    status_url = str(request.url)\n    return QueryResults(request=request, next_url=status_url)\n\n\n@app.post(\"/v1/statement\")\nasync def query(request: Request):\n    \"\"\"\n    Main 
endpoint returning query results\n    in the Presto wire format.\n    \"\"\"\n    try:\n        sql = (await request.body()).decode().strip()\n        # required for PrestoDB JDBC driver compatibility\n        # replaces queries to unsupported `system` catalog with queries to `system_jdbc`\n        # schema created by `create_meta_data(context)` when `jdbc_metadata=True`\n        # TODO: explore Trino which should make JDBC compatibility easier but requires\n        # changing response headers (see https://github.com/dask-contrib/dask-sql/pull/351)\n        sql = sql.replace(\"system.jdbc\", \"system_jdbc\")\n        df = request.app.c.sql(sql)\n\n        if df is None:\n            return DataResults(df, request)\n\n        uuid = str(uuid4())\n        request.app.future_list[uuid] = request.app.client.compute(df)\n        logger.debug(f\"Registering {sql} with uuid {uuid}.\")\n\n        status_url = str(\n            request.url.replace(path=request.app.url_path_for(\"status\", uuid=uuid))\n        )\n        cancel_url = str(\n            request.url.replace(path=request.app.url_path_for(\"cancel\", uuid=uuid))\n        )\n        return QueryResults(request=request, next_url=status_url, cancel_url=cancel_url)\n    except Exception as e:\n        return ErrorResults(e, request=request)\n\n\ndef run_server(\n    context: Context = None,\n    client: dask.distributed.Client = None,\n    host: str = \"0.0.0.0\",\n    port: int = 8080,\n    startup=False,\n    log_level=None,\n    blocking: bool = True,\n    jdbc_metadata: bool = False,\n):  # pragma: no cover\n    \"\"\"\n    Run an HTTP server for answering SQL queries using ``dask-sql``.\n    It uses the `Presto Wire Protocol <https://github.com/prestodb/presto/wiki/HTTP-Protocol>`_\n    for communication.\n    This means it has a single POST endpoint `/v1/statement`, which answers\n    SQL queries (as a string in the body) with the output as JSON\n    (in the format described in the documentation above).\n    
Every SQL expression that ``dask-sql`` understands can be used here.\n\n    See :ref:`server` for more information.\n\n    Note:\n        The presto protocol also includes some statistics on the query\n        in the response.\n        These statistics are currently only filled with placeholder variables.\n\n    Args:\n        context (:obj:`dask_sql.Context`): If set, use this context instead of an empty one.\n        client (:obj:`dask.distributed.Client`): If set, use this dask client instead of a new one.\n        host (:obj:`str`): The host interface to listen on (defaults to all interfaces)\n        port (:obj:`int`): The port to listen on (defaults to 8080)\n        startup (:obj:`bool`): Whether to wait until Apache Calcite was loaded\n        log_level: (:obj:`str`): The log level of the server and dask-sql\n        blocking: (:obj:`bool`): If running in an environment with an event loop (e.g. a jupyter notebook),\n                do not block. The server can be stopped with `context.stop_server()` afterwards.\n        jdbc_metadata: (:obj:`bool`): If enabled create JDBC metadata tables using schemas and tables in\n                the current dask_sql context\n\n    Example:\n        It is possible to run an SQL server by using the CLI script ``dask-sql-server``\n        or by calling this function directly in your user code:\n\n        .. code-block:: python\n\n            from dask_sql import run_server\n\n            # Create your pre-filled context\n            c = Context()\n            ...\n\n            run_server(context=c)\n\n        After starting the server, it is possible to send queries to it, e.g. with the\n        `presto CLI <https://prestosql.io/docs/current/installation/cli.html>`_\n        or via sqlalchemy (e.g. using the `PyHive <https://github.com/dropbox/PyHive#sqlalchemy>`_ package):\n\n        .. 
code-block:: python\n\n            from sqlalchemy.engine import create_engine\n            engine = create_engine('presto://localhost:8080/')\n\n            import pandas as pd\n            pd.read_sql_query(\"SELECT 1 + 1\", con=engine)\n\n        Of course, it is also possible to call the usual ``CREATE TABLE``\n        commands.\n\n        If in a jupyter notebook, you should run the following code\n\n        .. code-block:: python\n\n            from dask_sql import Context\n\n            c = Context()\n            c.run_server(blocking=False)\n\n            ...\n\n            c.stop_server()\n\n        Note:\n            When running in a jupyter notebook without blocking,\n            it is not possible to access the SQL server from within the\n            notebook, e.g. using sqlalchemy.\n            Doing so will deadlock infinitely.\n\n    \"\"\"\n    if context is None:\n        context = Context()\n    _init_app(app, context=context, client=client)\n    if jdbc_metadata:\n        create_meta_data(context)\n\n    if startup:\n        app.c.sql(\"SELECT 1 + 1\").compute()\n\n    config = Config(app, host=host, port=port, log_level=log_level)\n    server = Server(config=config)\n\n    if blocking:\n        server.run()\n    else:\n        loop = asyncio.get_event_loop()\n        loop.create_task(server.serve())\n        context.sql_server = server\n\n\ndef main():  # pragma: no cover\n    \"\"\"\n    CLI version of the :func:`run_server` function.\n    \"\"\"\n    parser = ArgumentParser()\n    parser.add_argument(\n        \"--host\",\n        default=\"0.0.0.0\",\n        help=\"The host interface to listen on (defaults to all interfaces)\",\n    )\n    parser.add_argument(\n        \"--port\", default=8080, help=\"The port to listen on (defaults to 8080)\"\n    )\n    parser.add_argument(\n        \"--scheduler-address\",\n        default=None,\n        help=\"Connect to this dask scheduler if given\",\n    )\n    parser.add_argument(\n        
\"--log-level\",\n        default=None,\n        help=\"Set the log level of the server. Defaults to info.\",\n        choices=uvicorn.config.LOG_LEVELS,\n    )\n    parser.add_argument(\n        \"--load-test-data\",\n        default=False,\n        action=\"store_true\",\n        help=\"Preload some test data.\",\n    )\n    parser.add_argument(\n        \"--startup\",\n        default=False,\n        action=\"store_true\",\n        help=\"Wait until Apache Calcite was properly loaded\",\n    )\n\n    args = parser.parse_args()\n\n    client = None\n    if args.scheduler_address:\n        client = dask.distributed.Client(args.scheduler_address)\n\n    context = Context()\n    if args.load_test_data:\n        df = dask.datasets.timeseries(freq=\"1d\").reset_index(drop=False)\n        context.create_table(\"timeseries\", df.persist())\n\n    run_server(\n        context=context,\n        client=client,\n        host=args.host,\n        port=args.port,\n        startup=args.startup,\n        log_level=args.log_level,\n    )\n\n\ndef _init_app(\n    app: FastAPI,\n    context: Context = None,\n    client: dask.distributed.Client = None,\n):\n    app.c = context\n    app.future_list = {}\n\n    try:\n        client = client or dask.distributed.Client.current()\n    except ValueError:\n        client = dask.distributed.Client()\n    app.client = client\n"
  },
  {
    "path": "dask_sql/server/presto_jdbc.py",
    "content": "import logging\n\nimport pandas as pd\n\nfrom dask_sql.context import Context\n\nlogger = logging.getLogger(__name__)\n\n\ndef create_meta_data(c: Context):\n    \"\"\"\n    Creates the schema, table and column data for prestodb JDBC driver so that data can be viewed\n    in a database tool like DBeaver. It doesn't create a catalog entry although JDBC expects one\n    as dask-sql doesn't support catalogs. For both catalogs and procedures empty placeholder\n    tables are created.\n\n    The meta-data appears in a separate schema called system_jdbc largely because the JDBC driver\n    tries to access system.jdbc and it sufficiently so shouldn't clash with other schemas.\n\n    A function is required in the /v1/statement to change system.jdbc to system_jdbc and ignore\n    order by statements from the driver (as adjust_for_presto_sql above)\n\n    :param c: Context containing created tables\n    :return:\n    \"\"\"\n\n    if c is None:\n        logger.warning(\"Context None: jdbc meta data not created\")\n        return\n    catalog = \"\"\n    system_schema = \"system_jdbc\"\n    c.create_schema(system_schema)\n\n    # TODO: add support for catalogs in presto interface\n    # see https://github.com/dask-contrib/dask-sql/pull/351\n    # if catalog and len(catalog.strip()) > 0:\n    #     catalogs = pd.DataFrame().append(create_catalog_row(catalog), ignore_index=True)\n    #     c.create_table(\"catalogs\", catalogs, schema_name=system_schema)\n\n    schemas = pd.DataFrame(create_schema_row(), index=[0])\n    c.create_table(\"schemas\", schemas, schema_name=system_schema)\n    schema_rows = []\n\n    tables = pd.DataFrame(create_table_row(), index=[0])\n    c.create_table(\"tables\", tables, schema_name=system_schema)\n    table_rows = []\n\n    columns = pd.DataFrame(create_column_row(), index=[0])\n    c.create_table(\"columns\", columns, schema_name=system_schema)\n    column_rows = []\n\n    for schema_name, schema in c.schema.items():\n        
schema_rows.append(create_schema_row(catalog, schema_name))\n        for table_name, dc in schema.tables.items():\n            df = dc.df\n            logger.info(f\"schema {schema_name}, table {table_name}, {df}\")\n            table_rows.append(create_table_row(catalog, schema_name, table_name))\n            pos: int = 0\n            for column in df.columns:\n                pos = pos + 1\n                logger.debug(f\"column {column}\")\n                dtype = \"VARCHAR\"\n                if df[column].dtype == \"int64\" or df[column].dtype == \"int\":\n                    dtype = \"INTEGER\"\n                elif df[column].dtype == \"float64\" or df[column].dtype == \"float\":\n                    dtype = \"FLOAT\"\n                elif (\n                    df[column].dtype == \"datetime\"\n                    or df[column].dtype == \"datetime64[ns]\"\n                ):\n                    dtype = \"TIMESTAMP\"\n                column_rows.append(\n                    create_column_row(\n                        catalog,\n                        schema_name,\n                        table_name,\n                        dtype,\n                        df[column].name,\n                        str(pos),\n                    )\n                )\n\n    schemas = pd.DataFrame(schema_rows)\n    c.create_table(\"schemas\", schemas, schema_name=system_schema)\n    tables = pd.DataFrame(table_rows)\n    c.create_table(\"tables\", tables, schema_name=system_schema)\n    columns = pd.DataFrame(column_rows)\n    c.create_table(\"columns\", columns, schema_name=system_schema)\n\n    logger.info(f\"jdbc meta data ready for {len(table_rows)} tables\")\n\n\ndef create_catalog_row(catalog: str = \"\"):\n    return {\"TABLE_CAT\": catalog}\n\n\ndef create_schema_row(catalog: str = \"\", schema: str = \"\"):\n    return {\"TABLE_CATALOG\": catalog, \"TABLE_SCHEM\": schema}\n\n\ndef create_table_row(catalog: str = \"\", schema: str = \"\", table: str = \"\"):\n    return 
{\n        \"TABLE_CAT\": catalog,\n        \"TABLE_SCHEM\": schema,\n        \"TABLE_NAME\": table,\n        \"TABLE_TYPE\": \"\",\n        \"REMARKS\": \"\",\n        \"TYPE_CAT\": \"\",\n        \"TYPE_SCHEM\": \"\",\n        \"TYPE_NAME\": \"\",\n        \"SELF_REFERENCING_COL_NAME\": \"\",\n        \"REF_GENERATION\": \"\",\n    }\n\n\ndef create_column_row(\n    catalog: str = \"\",\n    schema: str = \"\",\n    table: str = \"\",\n    dtype: str = \"\",\n    column: str = \"\",\n    pos: str = \"\",\n):\n    return {\n        \"TABLE_CAT\": catalog,\n        \"TABLE_SCHEM\": schema,\n        \"TABLE_NAME\": table,\n        \"COLUMN_NAME\": column,\n        \"DATA_TYPE\": dtype,\n        \"TYPE_NAME\": dtype,\n        \"COLUMN_SIZE\": \"\",\n        \"BUFFER_LENGTH\": \"\",\n        \"DECIMAL_DIGITS\": \"\",\n        \"NUM_PREC_RADIX\": \"\",\n        \"NULLABLE\": \"\",\n        \"REMARKS\": \"\",\n        \"COLUMN_DEF\": \"\",\n        \"SQL_DATA_TYPE\": dtype,\n        \"SQL_DATETIME_SUB\": \"\",\n        \"CHAR_OCTET_LENGTH\": \"\",\n        \"ORDINAL_POSITION\": pos,\n        \"IS_NULLABLE\": \"\",\n        \"SCOPE_CATALOG\": \"\",\n        \"SCOPE_SCHEMA\": \"\",\n        \"SCOPE_TABLE\": \"\",\n        \"SOURCE_DATA_TYPE\": \"\",\n        \"IS_AUTOINCREMENT\": \"\",\n        \"IS_GENERATEDCOLUMN\": \"\",\n    }\n"
  },
  {
    "path": "dask_sql/server/responses.py",
    "content": "import uuid\n\nimport dask.dataframe as dd\nimport numpy as np\nimport pandas as pd\nfrom fastapi import Request\n\nfrom dask_sql.mappings import python_to_sql_type\n\n\nclass StageStats:\n    def __init__(self):\n        self.stageId = \"\"\n        self.state = \"\"\n        self.done = True\n        self.nodes = 0\n        self.totalSplits = 0\n        self.queuedSplits = 0\n        self.runningSplits = 0\n        self.completedSplits = 0\n        self.cpuTimeMillis = 0\n        self.wallTimeMillis = 0\n        self.processedRows = 0\n        self.processedBytes = 0\n        self.subStages = []\n\n\nclass StatementStats:\n    def __init__(self):\n        self.state = \"\"\n        self.queued = False\n        self.scheduled = False\n        self.nodes = 0\n        self.totalSplits = 0\n        self.queuedSplits = 0\n        self.runningSplits = 0\n        self.completedSplits = 0\n        self.cpuTimeMillis = 0\n        self.wallTimeMillis = 0\n        self.queuedTimeMillis = 0\n        self.elapsedTimeMillis = 0\n        self.processedRows = 0\n        self.processedBytes = 0\n        self.peakMemoryBytes = 0\n        self.peakTotalMemoryBytes = 0\n        self.peakTaskTotalMemoryBytes = 0\n        self.spilledBytes = 0\n        self.rootStage = StageStats()\n\n\nclass QueryResults:\n    def __init__(self, request: Request, next_url: str = None, cancel_url: str = None):\n        empty_url = str(request.url.replace(path=request.app.url_path_for(\"empty\")))\n\n        self.id = str(uuid.uuid4())\n        self.infoUri = empty_url\n        if next_url:\n            self.nextUri = next_url\n        if cancel_url:\n            self.partialCancelUri = cancel_url\n        self.stats = StatementStats()\n        self.warnings = []\n\n\nclass DataResults(QueryResults):\n    @staticmethod\n    def get_column_description(df):\n        sql_types = [str(python_to_sql_type(t)).lower() for t in df.dtypes]\n        column_names = df.columns\n        return [\n   
         {\n                \"name\": column_name,\n                \"type\": sql_type,\n                \"typeSignature\": {\n                    \"rawType\": sql_type,\n                    \"arguments\": []\n                    if sql_type not in (\"char\", \"varchar\")\n                    else [{\"kind\": \"LONG\", \"value\": 10}],\n                },\n            }\n            for column_name, sql_type in zip(column_names, sql_types)\n        ]\n\n    @staticmethod\n    def get_data_description(df):\n        if hasattr(df, \"to_pandas\"):\n            df = df.to_pandas()\n        return [\n            DataResults.convert_row(row)\n            for row in df.itertuples(index=False, name=None)\n        ]\n\n    @staticmethod\n    def convert_cell(cell):\n        try:\n            if pd.isna(cell):\n                return None\n            elif np.isnan(cell):  # pragma: no cover\n                return \"NaN\"\n            elif np.isposinf(cell):\n                return \"+Infinity\"\n            elif np.isneginf(cell):  # pragma: no cover\n                return \"-Infinity\"\n        except TypeError:  # pragma: no cover\n            pass\n\n        try:\n            return cell.item()\n        except AttributeError:\n            pass\n\n        return cell\n\n    @staticmethod\n    def convert_row(row):\n        return [DataResults.convert_cell(cell) for cell in row]\n\n    def __init__(self, df: dd.DataFrame, request: Request):\n        super().__init__(request)\n\n        if df is None:\n            return\n\n        self.columns = self.get_column_description(df)\n        self.data = self.get_data_description(df)\n\n\nclass ErrorResults(QueryResults):\n    def __init__(self, error: Exception, request: Request):\n        super().__init__(request)\n\n        self.error = QueryError(error)\n\n\nclass QueryError:\n    def __init__(self, error: Exception):\n        self.message = str(error)\n        self.errorCode = 0\n        self.errorName = str(type(error))\n 
       self.errorType = \"USER_ERROR\"\n\n        # FIXME: ParserErrors currently don't contain information on where the syntax error occurred\n        # try:\n        #     self.errorLocation = {\n        #         \"lineNumber\": error.from_line + 1,\n        #         \"columnNumber\": error.from_col + 1,\n        #     }\n        # except AttributeError:  # pragma: no cover\n        #     pass\n"
  },
  {
    "path": "dask_sql/sql-schema.yaml",
    "content": "properties:\n\n  sql:\n    type: object\n    properties:\n\n      aggregate:\n        type: object\n        properties:\n\n          split_out:\n            type: integer\n            description: |\n              Number of output partitions from an aggregation operation\n\n          split_every:\n            type: [integer, \"null\"]\n            description: |\n              Number of branches per reduction step from an aggregation operation.\n\n      identifier:\n        type: object\n        properties:\n\n          case_sensitive:\n            type: boolean\n            description: |\n              Whether sql identifiers are considered case sensitive while parsing.\n\n      join:\n        type: object\n        properties:\n\n          broadcast:\n            type: [boolean, number, \"null\"]\n            description: |\n              If boolean, it determines whether all joins should use the broadcast join algorithm.\n              If float, it's a value denoting dask's likelihood of selecting a broadcast join based\n              codepath over a shuffle based join. Concretely, dask will select a broadcast based join\n              algorithm if small_table.npartitions < log2(big_table.npartitions) * broadcast_bias\n              Note: Forcing a broadcast join might lead to perf issues or OOM errors in cases where the\n              broadcasted table is too large to fit on a single worker.\n\n      limit:\n        type: object\n        properties:\n\n          check-first-partition:\n            type: boolean\n            description: |\n              Whether or not to check the first partition length when computing a LIMIT without an OFFSET\n              on a table with a relatively simple Dask graph (i.e. 
only IO and/or partition-wise layers);\n              checking partition length triggers a Dask graph computation which can be slow for complex\n              queries, but can significantly reduce memory usage when querying a small subset of a large\n              table. Default is ``true``.\n\n      optimize:\n        type: boolean\n        description: |\n          Whether the first generated logical plan should be further optimized or used as is.\n\n      predicate_pushdown:\n        type: boolean\n        description: |\n          Whether to try pushing down filter predicates into IO (when possible).\n\n      dynamic_partition_pruning:\n        type: boolean\n        description: |\n          Whether to apply the dynamic partition pruning optimizer rule.\n\n      optimizer:\n        type: object\n        properties:\n\n          verbose:\n            type: boolean\n            description: |\n              The dynamic partition pruning optimizer rule can sometimes result in extremely long\n              c.explain() outputs which are not helpful to the user. Setting this option to true\n              allows the user to see the entire output, while setting it to false truncates the\n              output. Default is false.\n\n      fact_dimension_ratio:\n        type: [number, \"null\"]\n        description: |\n          Ratio of the size of the dimension tables to fact tables. Parameter for dynamic partition\n          pruning and join reorder optimizer rules.\n\n      max_fact_tables:\n        type: [integer, \"null\"]\n        description: |\n          Maximum number of fact tables to allow in a join. Parameter for join reorder optimizer\n          rule.\n\n      preserve_user_order:\n        type: [boolean, \"null\"]\n        description: |\n          Whether to preserve user-defined order of unfiltered dimensions. 
Parameter for join\n          reorder optimizer rule.\n\n      filter_selectivity:\n        type: [number, \"null\"]\n        description: |\n          Constant to use when determining the number of rows produced by a filtered relation.\n          Parameter for join reorder optimizer rule.\n\n      sort:\n        type: object\n        properties:\n\n          topk-nelem-limit:\n            type: integer\n            description: |\n              Total number of elements below which dask-sql should attempt to apply the top-k\n              optimization (when possible). ``nelem`` is defined as the limit or ``k`` value times the\n              number of columns. Default is 1000000, corresponding to a LIMIT clause of 1 million in a\n              1 column table.\n\n      mappings:\n        type: object\n        properties:\n\n          decimal_support:\n            type: string\n            description:\n              Decides how to handle decimal scalars/columns. ``\"pandas\"`` handling will treat decimals scalars and columns as floats and float64 columns, respectively, while ``\"cudf\"`` handling treats decimal scalars as ``decimal.Decimal`` objects and decimal columns as ``cudf.Decimal128Dtype`` columns, handling precision/scale accordingly. Default is ``\"pandas\"``, but ``\"cudf\"`` should be used if attempting to work with decimal columns on GPU.\n"
  },
  {
    "path": "dask_sql/sql.yaml",
    "content": "sql:\n  aggregate:\n    split_out: 1\n    split_every: null\n\n  identifier:\n    case_sensitive: True\n\n  join:\n    broadcast: null\n\n  limit:\n    check-first-partition: True\n\n  optimize: True\n\n  predicate_pushdown: True\n\n  dynamic_partition_pruning: True\n\n  optimizer:\n    verbose: False\n\n  fact_dimension_ratio: null\n\n  max_fact_tables: null\n\n  preserve_user_order: null\n\n  filter_selectivity: null\n\n  sort:\n    topk-nelem-limit: 1000000\n\n  mappings:\n    decimal_support: \"pandas\"\n"
  },
  {
    "path": "dask_sql/utils.py",
    "content": "import importlib\nimport logging\nfrom collections import defaultdict\nfrom datetime import datetime\nfrom typing import Any\nfrom uuid import uuid4\n\nimport dask.dataframe as dd\nimport numpy as np\nimport pandas as pd\n\nfrom dask_sql._datafusion_lib import SqlTypeName\nfrom dask_sql.datacontainer import DataContainer\nfrom dask_sql.mappings import sql_to_python_value\n\nlogger = logging.getLogger(__name__)\n\n\ndef is_frame(df):\n    \"\"\"\n    Check if something is a dataframe (and not a scalar or none)\n    \"\"\"\n    return (\n        df is not None\n        and not np.isscalar(df)\n        and not isinstance(df, type(pd.NA))\n        and not isinstance(df, datetime)\n    )\n\n\ndef is_datetime(obj):\n    \"\"\"\n    Check if a scalar or a series is of datetime type\n    \"\"\"\n    return pd.api.types.is_datetime64_any_dtype(obj) or isinstance(obj, datetime)\n\n\ndef convert_to_datetime(df):\n    \"\"\"\n    Covert a scalar or a series to datetime type\n    \"\"\"\n    if is_frame(df):\n        df = df.dt\n    else:\n        df = pd.to_datetime(df)\n    return df\n\n\ndef is_cudf_type(obj):\n    \"\"\"\n    Check if an object is a cuDF type\n    \"\"\"\n    types = [\n        str(type(obj)),\n        str(getattr(obj, \"_partition_type\", \"\")),\n        str(getattr(obj, \"_meta\", \"\")),\n    ]\n    return any(\"cudf\" in obj_type for obj_type in types)\n\n\nclass Pluggable:\n    \"\"\"\n    Helper class for everything which can be extended by plugins.\n    Basically just a mapping of a name to the stored plugin\n    for ever class.\n    Please note that the plugins are stored\n    in this single class, which makes simple extensions possible.\n    \"\"\"\n\n    __plugins = defaultdict(dict)\n\n    @classmethod\n    def add_plugin(cls, names, plugin, replace=True):\n        \"\"\"Add a plugin with the given name\"\"\"\n        if isinstance(names, str):\n            names = [names]\n\n        if not replace and all(name in 
Pluggable.__plugins[cls] for name in names):\n            return\n\n        Pluggable.__plugins[cls].update({name: plugin for name in names})\n\n    @classmethod\n    def get_plugin(cls, name):\n        \"\"\"Get a plugin with the given name\"\"\"\n        return Pluggable.__plugins[cls][name]\n\n    @classmethod\n    def get_plugins(cls):\n        \"\"\"Return all registered plugins\"\"\"\n        return list(Pluggable.__plugins[cls].values())\n\n\nclass ParsingException(Exception):\n    \"\"\"\n    Helper class to format SQL validation and parsing\n    exceptions in a nicer way\n    \"\"\"\n\n    def __init__(self, sql, validation_exception_string):\n        \"\"\"\n        Create a new exception out of the SQL query and the exception text\n        raised by the SQL parser.\n        \"\"\"\n        super().__init__(validation_exception_string.strip())\n\n\nclass OptimizationException(Exception):\n    \"\"\"\n    Helper class for formatting exceptions that occur while trying to\n    optimize a logical plan\n    \"\"\"\n\n    def __init__(self, exception_string):\n        \"\"\"\n        Create a new exception out of the exception text raised by DataFusion\n        \"\"\"\n        super().__init__(exception_string.strip())\n\n\nclass LoggableDataFrame:\n    \"\"\"Small helper class to print resulting dataframes or series in logging messages\"\"\"\n\n    def __init__(self, df):\n        self.df = df\n\n    def __str__(self):\n        df = self.df\n        if isinstance(df, pd.Series) or isinstance(df, dd.Series):\n            return f\"Series: {(df.name, df.dtype)}\"\n        if isinstance(df, pd.DataFrame) or isinstance(df, dd.DataFrame):\n            return f\"DataFrame: {[(col, dtype) for col, dtype in zip(df.columns, df.dtypes)]}\"\n\n        elif isinstance(df, DataContainer):\n            cols = df.column_container.columns\n            dtypes = {col: dtype for col, dtype in zip(df.df.columns, df.df.dtypes)}\n            mapping = 
df.column_container.get_backend_by_frontend_index\n            dtypes = [dtypes[mapping(index)] for index in range(len(cols))]\n            return f\"DataFrame: {[(col, dtype) for col, dtype in zip(cols, dtypes)]}\"\n\n        return f\"Literal: {df}\"\n\n\ndef convert_sql_kwargs(\n    sql_kwargs: dict[str, str],\n) -> dict[str, Any]:\n    \"\"\"\n    Convert the Rust Vec of key/value pairs into a Dict containing the keys and values\n    \"\"\"\n\n    def convert_literal(value):\n        if value.isCollection():\n            operator_mapping = {\n                \"SqlTypeName.ARRAY\": list,\n                \"SqlTypeName.MAP\": lambda x: dict(zip(x[::2], x[1::2])),\n                \"SqlTypeName.MULTISET\": set,\n                \"SqlTypeName.ROW\": tuple,\n            }\n\n            operator = operator_mapping[str(value.getSqlType())]\n            operands = [convert_literal(o) for o in value.getOperandList()]\n\n            return operator(operands)\n        elif value.isKwargs():\n            return convert_sql_kwargs(value.getKwargs())\n        else:\n            literal_type = value.getSqlType()\n            literal_value = value.getSqlValue()\n\n            if literal_type == SqlTypeName.VARCHAR:\n                return value.getSqlValue()\n            elif literal_type == SqlTypeName.BIGINT and \".\" in literal_value:\n                literal_type = SqlTypeName.DOUBLE\n\n            python_value = sql_to_python_value(literal_type, literal_value)\n            return python_value\n\n    return {key: convert_literal(value) for key, value in dict(sql_kwargs).items()}\n\n\ndef import_class(name: str) -> type:\n    \"\"\"\n    Import a class with the given name by loading the module\n    and referencing the class in the module\n    \"\"\"\n    module_path, class_name = name.rsplit(\".\", 1)\n    module = importlib.import_module(module_path)\n    return getattr(module, class_name)\n\n\ndef new_temporary_column(df: dd.DataFrame) -> str:\n    \"\"\"Return a new 
column name which is currently not in use\"\"\"\n    while True:\n        col_name = str(uuid4())\n\n        if col_name not in df.columns:\n            return col_name\n        else:  # pragma: no cover\n            continue\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = source\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/environment.yml",
    "content": "name: dask-sql-docs\nchannels:\n  - conda-forge\ndependencies:\n  - python=3.9\n  - sphinx>=4.0.0\n  - sphinx-tabs\n  - dask-sphinx-theme>=2.0.3\n  - dask>=2024.4.1\n  - pandas>=1.4.0\n  - fugue>=0.7.3\n  # FIXME: https://github.com/fugue-project/fugue/issues/526\n  - triad<0.9.2\n  - fastapi>=0.92.0\n  - httpx>=0.24.1\n  - uvicorn>=0.14\n  - tzlocal>=2.1\n  - prompt_toolkit>=3.0.8\n  - pygments>=2.7.1\n  - tabulate\n  - ucx-proc=*=cpu\n  - rust=1.72\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM Command file for Sphinx documentation\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sphinx-build\r\n)\r\nset SOURCEDIR=source\r\nset BUILDDIR=build\r\n\r\nif \"%1\" == \"\" goto help\r\n\r\n%SPHINXBUILD% >NUL 2>NUL\r\nif errorlevel 9009 (\r\n\techo.\r\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\r\n\techo.installed, then set the SPHINXBUILD environment variable to point\r\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\r\n\techo.may add the Sphinx directory to PATH.\r\n\techo.\r\n\techo.If you don't have Sphinx installed, grab it from\r\n\techo.https://www.sphinx-doc.org/\r\n\texit /b 1\r\n)\r\n\r\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\ngoto end\r\n\r\n:help\r\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\n\r\n:end\r\npopd\r\n"
  },
  {
    "path": "docs/requirements-docs.txt",
    "content": "sphinx>=4.0.0\nsphinx-tabs\ndask-sphinx-theme>=3.0.0\ndask>=2024.4.1\npandas>=1.4.0\nfugue>=0.7.3\n# FIXME: https://github.com/fugue-project/fugue/issues/526\ntriad<0.9.2\nfastapi>=0.92.0\nhttpx>=0.24.1\nuvicorn>=0.14\ntzlocal>=2.1\nprompt_toolkit>=3.0.8\npygments>=2.7.1\ntabulate\nmaturin>=1.3,<1.4\n"
  },
  {
    "path": "docs/source/api.rst",
    "content": ".. _api:\n\nAPI Documentation\n=================\n\n.. autoclass:: dask_sql.Context\n   :members:\n   :undoc-members:\n\n.. autofunction:: dask_sql.run_server\n\n.. autofunction:: dask_sql.cmd_loop\n\n.. autoclass:: dask_sql.integrations.fugue.DaskSQLExecutionEngine\n   :members:\n\n.. autofunction:: dask_sql.integrations.fugue.fsql_dask\n"
  },
  {
    "path": "docs/source/best_practices.rst",
    "content": ".. _best_practices:\n\nBest Practices and Performance Tips\n===================================\n\nSort and Use Read Filtering\n---------------------------\n\nIf you often read by key ranges or perform lots of logic with groups of related records, you should consider using Dask Dataframe's `shuffle <https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.shuffle.html>`_.\nThis operation ensures that all rows of a given key will be within a single partition.\nThis is helpful for querying records on a specific key or keys such as customer IDs or session keys, as it allows Dask to skip partitions based on the partition min and max values thus avoiding reading each record.\nThis can save a large amount of IO time and is especially helpful when using a network file system.\n\nFor example, querying a specific pickup time from a taxi dataset ends up returning a result with over 200 partitions as each of these partitions needs to be checked for that key.\n\n.. code-block:: python\n\n    ddf = dd.read_parquet('/data/taxi_pq_2GB', split_row_groups=False)\n    c.create_table('taxi_unsorted', ddf)\n    c.sql(\"select * from taxi_unsorted where DAYOFMONTH(pickup_datetime) = 15\").npartitions\n\n.. code-block::\n\n    244\n\nBut, if you were to instead sort by the pickup time and use the ``DISTRIBUTE BY`` operation, which is equivalent to Dask Dataframe's shuffle, you can reduce the number of partitions in the result to 1.\n\n.. code-block:: python\n\n    def intra_partition_sort(df, sort_keys):\n        return df.sort_values(sort_keys)\n\n    c.sql(\"\"\"\n    SELECT\n        DAYOFMONTH(pickup_datetime) AS dom,\n        HOUR(pickup_datetime) AS hr,\n        *\n    FROM\n        taxi_unsorted\n    DISTRIBUTE BY dom\n    \"\"\").map_partitions(intra_partition_sort, ['dom', 'hr']).to_parquet('/data/taxi_sorted')\n\n.. 
code-block:: python\n\n    sorted_ddf = dd.read_parquet(\n        '/data/taxi_sorted',\n        split_row_groups=False,\n        filters=[\n            [(\"dom\", \"==\", 15)]\n        ]\n    )\n\n    c.create_table(\"taxi_sorted\", sorted_ddf)\n    c.sql(\"SELECT * FROM taxi_sorted WHERE dom = 15\").npartitions\n\n.. code-block::\n\n    1\n\nThis comes with a large corresponding boost in computation speed. For example,\n\n.. code-block:: python\n\n    %%time\n    c.sql(\"SELECT COUNT(*) FROM taxi_unsorted WHERE DAYOFMONTH(pickup_datetime) = 15\").compute()\n\n.. code-block::\n\n    CPU times: user 2.4 s, sys: 275 ms, total: 2.68 s\n    Wall time: 2.58 s\n\n.. code-block:: python\n\n    %%time\n    c.sql(\"SELECT COUNT(*) FROM taxi_sorted WHERE dom = 15\").compute()\n\n.. code-block::\n\n    CPU times: user 318 ms, sys: 21.7 ms, total: 340 ms\n    Wall time: 274 ms\n\n\nFor a deeper dive into read filtering with Dask, check out `Filtered Reading with RAPIDS & Dask to Optimize ETL <https://medium.com/rapids-ai/filtered-reading-with-rapids-dask-to-optimize-etl-5f1624f4be55>`_.\n\nIn many cases Dask-SQL can automate sorting and read filtering with its predicate pushdown support.\n\nFor example, the query\n\n.. 
code-block:: sql\n\n    SELECT\n        COUNT(*)\n    FROM\n        taxi\n    WHERE\n        DAYOFMONTH(pickup_datetime) = 15\n\nwould automatically perform the same sorting and read filtering logic as the previous section.\n\nAvoid Unnecessary Parallelism\n-----------------------------\n\nAdditionally, more tasks added to the Dask graph means more overhead added by the scheduler which can be\na major performance inhibitor at large scales.\n\nFor CPUs this isn't as much of an issue, as CPUs tend to have allowance for more workers and CPU tasks tend to take longer, so the additional overhead is relatively less impactful.\nBut, for GPUs there's typically only one worker per GPU and tasks tend to be shorter, so the overhead added by a large number of tasks can greatly affect performance.\n\nImprove performance by only creating tasks as necessary. For example, splitting row groups creates more tasks so avoid this if possible.\n\n.. code-block:: python\n\n    weather_dir = '/data/weather_pq_2GB/*.parquet'\n\n\n.. code-block:: sql\n\n    CREATE OR REPLACE TABLE weather_split WITH (\n        location = '{weather_dir}',\n        gpu=True,\n        split_row_groups=True\n    )\n\n.. code-block:: sql\n\n    SELECT COUNT(*) FROM weather_split WHERE type='PRCP'\n\n\n.. code-block:: sql\n\n    CREATE OR REPLACE TABLE weather_nosplit WITH (\n        location = '{weather_dir}',\n        gpu=True,\n        split_row_groups=False\n    )\n\n.. 
code-block:: sql\n\n    SELECT COUNT(*) FROM weather_nosplit WHERE type='PRCP'\n\nUse broadcast joins when possible\n---------------------------------\n\nJoins and grouped aggregations typically require communication between workers, which can be expensive.\nBroadcast joins can help reduce this communication in the case of joining a small table to a large table by just sending the small table to each partition of the large table.\nHowever, in Dask-SQL this only works when the small table is a single partition.\n\nFor example, if you read in some tables and concatenate them with a ``UNION ALL`` operation\n\n.. code-block:: sql\n\n    CREATE OR REPLACE TABLE precip AS\n    SELECT\n        station_id,\n        substring(\"date\", 0, 4) as yr,\n        substring(\"date\", 5, 2) as mth,\n        substring(\"date\", 7, 2) as dy,\n        val*1/10*0.0393701 as inches\n    FROM weather_nosplit\n    WHERE type='PRCP'\n\n.. code-block:: sql\n\n    CREATE OR REPLACE TABLE atlanta_stations WITH (\n        location = '/data/atlanta_stations/*.parquet',\n        gpu=True\n    )\n\n.. code-block:: sql\n\n    CREATE OR REPLACE TABLE seattle_stations WITH (\n        location = '/data/seattle_stations/*.parquet',\n        gpu=True\n    )\n\n\n.. code-block:: sql\n\n    CREATE OR REPLACE TABLE city_stations AS\n    SELECT * FROM atlanta_stations\n    UNION ALL\n    SELECT * FROM seattle_stations\n\nyou get a new table that has two partitions. Then if you use it in a join\n\n.. 
code-block:: sql\n\n    SELECT\n        yr,\n        city,\n        CASE WHEN city='Atlanta' THEN\n            sum(inches)/{atl_stations}\n        ELSE\n            sum(inches)/{seat_stations}\n        END AS inches\n    FROM precip\n    JOIN city_stations\n    ON precip.station_id = city_stations.station_id\n    GROUP BY yr, city\n    ORDER BY yr ASC\n\nDask-SQL won't perform a broadcast join and will instead perform a traditional join with a corresponding slow compute time.\nHowever, if you were to repartition the smaller table to a single partition and rerun the operation\n\n.. code-block:: python\n\n    c.create_table(\"city_stations\", c.sql(\"select * from city_stations\").repartition(npartitions=1))\n\n.. code-block:: sql\n\n    SELECT\n        yr,\n        city,\n        CASE WHEN city='Atlanta' THEN\n            sum(inches)/{atl_stations}\n        ELSE\n            sum(inches)/{seat_stations}\n        END AS inches\n    FROM precip\n    JOIN city_stations\n    ON precip.station_id = city_stations.station_id\n    GROUP BY yr, city\n    ORDER BY yr ASC\n\nDask-SQL is able to recognize this as a broadcast join and the result is a significantly faster compute time.\n\nDask-SQL also supports biasing the heuristic Dask uses to determine whether to use a broadcast join through the ``sql.join.broadcast`` config option.\nThis option passes either a boolean or a float value to the ``broadcast`` argument in Dask's `merge <https://docs.dask.org/en/stable/generated/dask.dataframe.multi.DataFrame.merge.html?highlight=broadcast_join#dask.dataframe.multi.DataFrame.merge>`_ function.\nIn the case of passing a float, a larger value makes Dask more likely to use a broadcast join.\n\nFor example,\n\n.. code-block:: python\n\n    c.sql(query, config_options={\"sql.join.broadcast\": True})\n\nwould instruct Dask to always use a broadcast join if supported for the query whereas\n\n.. 
code-block:: python\n\n    c.sql(query, config_options={\"sql.join.broadcast\": 0.7})\n\nwould instruct Dask to use ``0.7`` as the ``broadcast_bias`` in its heuristic for deciding whether to use a broadcast join.\n\nOptimize Partition Sizes for GPUs\n---------------------------------\n\nFile formats like `Apache ORC <https://orc.apache.org/>`_ and `Apache Parquet <https://parquet.apache.org/>`_ are designed so that they can be pulled from disk and deserialized by CPUs quickly.\nHowever, loading data into GPUs has a substantial additional cost in the form of transfers from CPU to GPU memory.\nMinimizing that cost is often achieved by increasing partition size.\nEven when using Dask-SQL on GPUs, upstream CPU systems will likely produce small files, resulting in small partitions.\nIt's worth taking the time to repartition to larger partition sizes before querying the files on GPUs, especially when querying the same files multiple times.\n\nThere's no single optimal size, so choose a size that's tuned for your workflow.\nOperations like joins and concatenations greatly increase GPU memory utilization, even if temporarily, but if you're not performing many of these operations, the larger the partition size the better.\nLarger partition sizes increase disk-to-GPU throughput and keep GPU utilization higher for faster runtimes.\n\nWe recommend a starting point of around 2 GB of uncompressed data per partition for GPUs.\nIt's usually not necessary to change from default settings when running Dask-SQL on CPUs, but if you want to manually set partition sizes, we've found 128-256 MB per partition to be a good starting place.\n"
  },
  {
    "path": "docs/source/cmd.rst",
    "content": ".. _cmd:\n\nCommand Line Tool\n=================\n\nIt is also possible to run a small CLI tool for testing out some\nSQL commands quickly.\n\nYou can either call the CLI tool (after installation) directly\n\n.. code-block:: bash\n\n    dask-sql\n\nor by running these lines of code\n\n.. code-block:: python\n\n    from dask_sql import cmd_loop\n\n    cmd_loop()\n\nSome options can be set, e.g. to preload some testdata.\nHave a look into :func:`~dask_sql.cmd_loop` or call\n\n.. code-block:: bash\n\n    dask-sql --help\n\nOf course, it is also possible to call the usual ``CREATE TABLE``\ncommands.\n\nVery similar as described in :ref:`server`, it is possible to preregister your own data sources\nor choose a dask scheduler to connect to.\n"
  },
  {
    "path": "docs/source/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n# contents of docs/conf.py\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport sys\nfrom datetime import datetime\n\n# -- Path setup --------------------------------------------------------------\n\n\nsys.path.insert(0, os.path.abspath(\"..\"))\n\n\n# -- Project information -----------------------------------------------------\n\nproject = \"dask-sql\"\ncopyright = f\"{datetime.today().year}, Nils Braun\"\nauthor = \"Nils Braun\"\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.napoleon\",\n    \"sphinx_tabs.tabs\",\n    \"dask_sphinx_theme.ext.dask_config_sphinx_ext\",\n]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = [\"_templates\"]\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = [\"_build\", \"Thumbs.db\", \".DS_Store\"]\n\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  
See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = \"dask_sphinx_theme\"\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\n# html_static_path = [\"_static\"]\n\n# Make sure to reference the correct master document\nmaster_doc = \"index\"\n\n# Do not show type mappings\nautodoc_typehints = \"none\"\n\n# disable collapsible tabs\nsphinx_tabs_disable_tab_closing = True\n"
  },
  {
    "path": "docs/source/configuration.rst",
    "content": ".. _configuration:\n\nConfiguration in Dask-SQL\n==========================\n\n``dask-sql`` supports a list of configuration options to configure behavior of certain operations.\n``dask-sql`` uses `Dask's config <https://docs.dask.org/en/stable/configuration.html>`_\nmodule and configuration options can be specified with YAML files, via environment variables,\nor directly, either through the `dask.config.set <https://docs.dask.org/en/stable/configuration.html#dask.config.set>`_ method\nor the ``config_options`` argument in the :func:`dask_sql.Context.sql` method.\n\nConfiguration Reference\n-----------------------\n\n.. dask-config-block::\n    :location: sql\n    :config: https://raw.githubusercontent.com/dask-contrib/dask-sql/main/dask_sql/sql.yaml\n    :schema: https://raw.githubusercontent.com/dask-contrib/dask-sql/main/dask_sql/sql-schema.yaml\n"
  },
  {
    "path": "docs/source/custom.rst",
    "content": ".. _custom:\n\nCustom Functions and Aggregations\n=================================\n\nAdditional to the included SQL functionalities, it is possible to include custom functions and aggregations into the SQL queries of ``dask-sql``.\nThe custom functions are classified into scalar functions and aggregations.\nIf you want to combine Machine Learning with SQL, you might also be interested in :ref:`machine_learning`.\n\nScalar Functions\n----------------\n\nA scalar function (such as :math:`x \\to x^2`) turns a given column into another column of the same length.\nIt can be registered for usage in SQL with the :func:`~dask_sql.Context.register_function` method.\n\nExample:\n\n.. code-block:: python\n\n    def f(x):\n        return x ** 2\n\n    c.register_function(f, \"f\", [(\"x\", np.int64)], np.int64)\n\nThe registration gives a name to the function and also adds type information on the input types and names, as well as the return type.\nAll usual numpy types (e.g. ``np.int64``) and pandas types (``Int64``) are supported.\n\nAfter registration, the function can be used as any other usual SQL function:\n\n.. code-block:: python\n\n    c.sql(\"SELECT f(column) FROM data\")\n\nScalar functions can have one or more input parameters and can combine columns and literal values.\n\nRow-Wise Pandas UDFs\n--------------------\nIn some cases it may be easier to write custom functions which process a dict like row object, such as those consumed by ``pandas.DataFrame.apply``.\nThese functions may be registered as above and flagged as row UDFs using the `row_udf` keyword argument:\n\n.. 
code-block:: python\n\n    def f(row):\n        return row['a'] + row['b']\n\n    c.register_function(f, \"f\", [(\"a\", np.int64), (\"b\", np.int64)], np.int64, row_udf=True)\n    c.sql(\"SELECT f(a, b) FROM data\")\n\n.. note::\n\n    Row UDFs use ``apply``, which may have unpredictable performance characteristics depending on the function and dataframe library.\n\nUDFs written in this way can also be extended to accept scalar arguments along with the incoming row:\n\n.. code-block:: python\n\n    def f(row, k):\n        return row['a'] + k\n\n    c.register_function(f, \"f\", [(\"a\", np.int64), (\"k\", np.int64)], np.int64, row_udf=True)\n    c.sql(\"SELECT f(a, 42) FROM data\")\n\n\nAggregation Functions\n---------------------\n\nAggregation functions run on a single column and turn it into a single value.\nThis means they can only be used in ``GROUP BY`` aggregations.\nThey can be registered with the :func:`~dask_sql.Context.register_aggregation` method.\nThis time, however, an instance of a :class:`dask.dataframe.Aggregation` needs to be passed\ninstead of a plain function.\nMore information on Dask aggregations can be found in the\n`Dask documentation <https://docs.dask.org/en/latest/dataframe-groupby.html#aggregate>`_.\n\nExample:\n\n.. code-block:: python\n\n    my_sum = dd.Aggregation(\"my_sum\", lambda x: x.sum(), lambda x: x.sum())\n    c.register_aggregation(my_sum, \"my_sum\", [(\"x\", np.float64)], np.float64)\n\n    c.sql(\"SELECT my_sum(other_column) FROM df GROUP BY column\")\n\n.. note::\n\n    Only a single function can exist under a given name,\n    regardless of whether it is an aggregation function or a scalar function.\n"
  },
  {
    "path": "docs/source/data_input.rst",
    "content": ".. _data_input:\n\nData Loading and Input\n======================\n\nBefore data can be queried with ``dask-sql``, it needs to be loaded into the Dask cluster (or local instance) and registered with the :class:`~dask_sql.Context`.\n``dask-sql`` supports all ``dask``-compatible `input formats  <https://docs.dask.org/en/latest/dataframe-create.html>`_, plus some additional formats only suitable for ``dask-sql``.\n\n1. Load it via Python\n---------------------\n\nYou can either use already created Dask DataFrames or create one by using the :func:`~dask_sql.Context.create_table` function.\nChances are high, there exists already a function to load your favorite format or location (e.g. S3 or hdfs).\nSee below for all formats understood by ``dask-sql``.\nMake sure to install required libraries both on the driver and worker machines:\n\n.. tabs::\n\n  .. group-tab:: CPU\n\n    .. code-block:: python\n\n        import dask.dataframe as dd\n        from dask_sql import Context\n\n        c = Context()\n        df = dd.read_csv(\"s3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv\")\n\n        c.create_table(\"my_data\", df)\n\n  .. group-tab:: GPU\n\n    .. code-block:: python\n\n        import dask.dataframe as dd\n        from dask_sql import Context\n\n        c = Context()\n        df = dd.read_csv(\"s3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv\")\n\n        c.create_table(\"my_data\", df, gpu=True)\n\nor in short (equivalent):\n\n.. tabs::\n\n  .. group-tab:: CPU\n\n    .. code-block:: python\n\n        from dask_sql import Context\n\n        c = Context()\n\n        c.create_table(\"my_data\", \"s3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv\")\n\n  .. group-tab:: GPU\n\n    .. code-block:: python\n\n        from dask_sql import Context\n\n        c = Context()\n\n        c.create_table(\"my_data\", \"s3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv\", gpu=True)\n\n2. 
Load it via SQL\n------------------\n\nIf you are connected to the SQL server implementation or you do not want to issue Python command calls, you can also\nload the data using SQL only.\n\n.. tabs::\n\n  .. group-tab:: CPU\n\n    .. code-block:: sql\n\n        CREATE TABLE my_data WITH (\n            format = 'csv',\n            location = 's3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv'\n        )\n\n  .. group-tab:: GPU\n\n    .. code-block:: sql\n\n        CREATE TABLE my_data WITH (\n            format = 'csv',\n            location = 's3://nyc-tlc/trip data/yellow_tripdata_2019-01.csv',\n            gpu = True\n        )\n\nThe parameters are the same as in the Python function described above.\nYou can find more information in :ref:`creation`.\n\n3. Persist and share data on the cluster\n----------------------------------------\n\nIn ``dask``, you can publish datasets with names into the cluster memory.\nThis allows the same data to be reused by multiple clients/users across multiple sessions.\n\nFor example, you can publish your data using the ``client.publish_dataset`` function of the ``distributed.Client``,\nand then later register it in the :class:`~dask_sql.Context` via SQL:\n\n.. code-block:: python\n\n    from dask.distributed import Client\n\n    # a dask.distributed Client\n    client = Client(...)\n    client.publish_dataset(my_df=df)\n\nLater in SQL:\n\n.. tabs::\n\n  .. group-tab:: CPU\n\n    .. code-block:: SQL\n\n        CREATE TABLE my_data WITH (\n            format = 'memory',\n            location = 'my_df'\n        )\n\n  .. group-tab:: GPU\n\n    .. code-block:: SQL\n\n        CREATE TABLE my_data WITH (\n            format = 'memory',\n            location = 'my_df',\n            gpu = True\n        )\n\nNote that the format is set to ``memory`` and the location is the name that was chosen when publishing the dataset.\n\nTo achieve the same thing from Python, you can use Dask's methods to get the dataset:\n\n.. tabs::\n\n  .. group-tab:: CPU\n\n    .. 
code-block:: python\n\n        df = client.get_dataset(\"my_df\")\n        c.create_table(\"my_data\", df)\n\n  .. group-tab:: GPU\n\n    .. code-block:: python\n\n        df = client.get_dataset(\"my_df\")\n        c.create_table(\"my_data\", df, gpu=True)\n\n\nInput Formats\n-------------\n\n``dask-sql`` understands (thanks to the large Dask ecosystem) a wide variety of input formats and input locations.\n\n* All formats and locations mentioned in `the Dask documentation <https://docs.dask.org/en/latest/dataframe-create.html>`_, including CSV, Parquet, and JSON.\n  Just pass in the location as a string (and possibly the format, e.g. \"csv\", if it is not clear from the file extension).\n  The data can be from local disk or many remote locations (S3, hdfs, Azure Filesystem, http, Google Filesystem, ...) - just prefix the path with the matching protocol.\n  Additional arguments passed to :func:`~dask_sql.Context.create_table` or ``CREATE TABLE`` are given to the ``read_<format>`` calls.\n\nExample:\n\n.. tabs::\n\n  .. group-tab:: CPU\n\n    .. code-block:: python\n\n      c.create_table(\n          \"my_data\",\n          \"s3://bucket-name/my-data-*.csv\",\n          storage_options={'anon': True}\n      )\n\n    .. code-block:: sql\n\n      CREATE TABLE my_data WITH (\n          format = 'csv', -- can also be omitted, as clear from the extension\n          location = 's3://bucket-name/my-data-*.csv',\n          storage_options = (\n              anon = True\n          )\n      )\n\n  .. group-tab:: GPU\n\n    .. code-block:: python\n\n      c.create_table(\n          \"my_data\",\n          \"s3://bucket-name/my-data-*.csv\",\n          gpu=True,\n          storage_options={'anon': True}\n      )\n\n    .. 
code-block:: sql\n\n      CREATE TABLE my_data WITH (\n          format = 'csv', -- can also be omitted, as clear from the extension\n          location = 's3://bucket-name/my-data-*.csv',\n          gpu = True,\n          storage_options = (\n              anon = True\n          )\n      )\n\n* If your data is already a pandas (or Dask) DataFrame, you can use it as is via the Python API\n  by passing it to :func:`~dask_sql.Context.create_table` directly.\n* You can connect ``dask-sql`` to an `intake <https://intake.readthedocs.io/en/latest/index.html>`_ catalog and\n  use the data registered there. Assuming you have an intake catalog stored in \"catalog.yaml\" (can also be\n  the URL of an intake server), you can read in a stored table \"intake_table\" either via Python\n\n  .. code-block:: python\n\n    import intake\n\n    catalog = intake.open_catalog(\"catalog.yaml\")\n    c.create_table(\"my_data\", catalog, intake_table_name=\"intake_table\")\n    # or\n    c.create_table(\"my_data\", \"catalog.yaml\", format=\"intake\", intake_table_name=\"intake_table\")\n\n  or via SQL:\n\n  .. code-block:: sql\n\n    CREATE TABLE my_data WITH (\n        format = 'intake',\n        location = 'catalog.yaml'\n    )\n\n  The argument ``intake_table_name`` is optional and defaults to the table name in ``dask-sql``.\n  With the argument ``catalog_kwargs`` you can control how the intake catalog object is created.\n  Additional arguments are forwarded to the ``to_dask()`` call of intake.\n* As an experimental feature, it is also possible to use data stored in the `Apache Hive <https://hive.apache.org/>`_\n  metastore. For this, ``dask-sql`` will retrieve the information on the storage location and format\n  from the metastore and will then register the raw data directly in the context.\n  This means that no Hive data query will be issued and you might be able to see a speed improvement.\n\n  It is possible to use either a ``pyhive.hive.Cursor`` or a ``sqlalchemy`` connection.\n\n  .. 
code-block:: python\n\n    from dask_sql import Context\n    from pyhive.hive import connect\n    import sqlalchemy\n\n    c = Context()\n\n    cursor = connect(\"hive-server\", 10000).cursor()\n    # or\n    cursor = sqlalchemy.create_engine(\"hive://hive-server:10000\").connect()\n\n    c.create_table(\"my_data\", cursor, hive_table_name=\"the_name_in_hive\")\n\n  or in SQL:\n\n  .. code-block:: sql\n\n    CREATE TABLE my_data WITH (\n        location = 'hive://hive-server:10000',\n        hive_table_name = 'the_name_in_hive'\n    )\n\n  Again, ``hive_table_name`` is optional and defaults to the table name in ``dask-sql``.\n  You can also control the database used in Hive via the ``hive_schema_name`` parameter.\n  Additional arguments are pushed to the internally called ``read_<format>`` functions.\n* Similarly, it is possible to load data from a `Databricks Cluster <https://docs.databricks.com/clusters/index.html>`_ (which is similar to a Hive metastore).\n\n  You need to have the ``databricks-dbapi`` package installed and ``fsspec >= 0.8.7``.\n  A token needs to be `generated <https://docs.databricks.com/dev-tools/api/latest/authentication.html>`_ for the accessing user.\n  The ``host``, ``port`` and ``http_path`` information can be found in the JDBC tab of the cluster.\n\n  .. code-block:: python\n\n    from dask_sql import Context\n    from sqlalchemy import create_engine\n\n    c = Context()\n\n    cursor = create_engine(f\"databricks+pyhive://token:{token}@{host}:{port}/\",\n                           connect_args={\"http_path\": http_path}).connect()\n\n    c.create_table(\"my_data\", cursor, hive_table_name=\"schema.table\",\n                   storage_options={\"instance\": host, \"token\": token})\n\n  or in SQL\n\n  .. 
code-block:: sql\n\n    CREATE TABLE my_data WITH (\n        location = 'databricks+pyhive://token:{token}@{host}:{port}/',\n        connect_args = (\n            http_path = '{http_path}'\n        ),\n        hive_table_name = 'schema.table',\n        storage_options = (\n            instance = '{host}',\n            token = '{token}'\n        )\n    )\n\n.. note::\n    For ``dask-sql`` it does not matter how you load your data.\n    In all of the cases shown, you can then use the specified table name to query your data\n    in a ``SELECT`` call.\n\n    Please note, however, that un-persisted data will be reread from its source (e.g. on S3 or disk)\n    on every query, whereas persisted data is only read once.\n    Persisting increases the query speed, but also prevents you from seeing external updates to your\n    data (until you reload it explicitly).\n"
  },
  {
    "path": "docs/source/fugue.rst",
    "content": "FugueSQL Integrations\n=====================\n\n`FugueSQL <https://fugue-tutorials.readthedocs.io/tutorials/fugue_sql/index.html>`_ is a related project that aims to provide a unified SQL interface for a variety of different computing frameworks, including Dask.\nWhile it offers a SQL engine with a larger set of supported commands, this comes at the cost of slower performance when using Dask in comparison to dask-sql.\nIn order to offer a \"best of both worlds\" solution, dask-sql includes several options to integrate with FugueSQL, using its faster implementation of SQL commands when possible and falling back on FugueSQL when necessary.\n\ndask-sql as a FugueSQL engine\n-----------------------------\n\nFugueSQL users unfamiliar with dask-sql can take advantage of its functionality by installing it in an environment alongside Fugue; this will automatically register :class:`dask_sql.integrations.fugue.DaskSQLExecutionEngine` as the default Dask execution engine for FugueSQL queries.\nFor more information and sample usage, see `Fugue — dask-sql as a FugueSQL engine <https://fugue-tutorials.readthedocs.io/tutorials/integrations/dasksql.html>`_.\n\nUsing FugueSQL on an existing ``Context``\n-----------------------------------------\n\ndask-sql users attempting to expand their SQL querying options for an existing ``Context`` can use :func:`dask_sql.integrations.fugue.fsql_dask`, which executes the provided query using FugueSQL, using the tables within the provided context as input.\nThe results of this query can then optionally be registered to the context:\n\n.. 
code-block:: python\n\n    # define a custom prepartition function for FugueSQL\n    def median(df: pd.DataFrame) -> pd.DataFrame:\n        df[\"y\"] = df[\"y\"].median()\n        return df.head(1)\n\n    # create a context with some tables\n    c = Context()\n    ...\n\n    # run a FugueSQL query using the context as input\n    query = \"\"\"\n        j = SELECT df1.*, df2.x\n            FROM df1 INNER JOIN df2 ON df1.key = df2.key\n            PERSIST\n        TAKE 5 ROWS PREPARTITION BY x PRESORT key\n        PRINT\n        TRANSFORM j PREPARTITION BY x USING median\n        PRINT\n        \"\"\"\n    result = fsql_dask(query, c, register=True)  # results aren't registered by default\n\n    assert \"j\" in result    # returns a dict of resulting tables\n    assert \"j\" in c.tables  # results are also registered to the context\n"
  },
  {
    "path": "docs/source/how_does_it_work.rst",
    "content": "How does it work?\n=================\n\nAt the core, ``dask-sql`` does two things:\n\n- Translates the SQL query using `Apache Arrow DataFusion <https://arrow.apache.org/datafusion/>`_ into a relational algebra,\n  represented by a `LogicalPlan enum <https://docs.rs/datafusion-expr/latest/datafusion_expr/enum.LogicalPlan.html>`_ - similar\n  to many other SQL engines (Hive, Flink, ...)\n- Converts this description of the query from the Rust enum into Dask API calls (and executes them) - returning a Dask dataframe.\n\nThe following example explains this in quite some technical details.\nFor most of the users, this level of technical understanding is not needed.\n\n1. SQL enters the library\n-------------------------\n\nNo matter of via the Python API (:ref:`api`), the command line client (:ref:`cmd`) or the server (:ref:`server`), eventually the SQL statement by the user will end up as a string in the function :func:`~dask_sql.Context.sql`.\n\n2. SQL is parsed\n----------------\n\nThis function will first give the SQL string to the dask_planner Rust crate via the ``PyO3`` library.\nInside this crate, Apache Arrow DataFusion is used to first parse the SQL string and then turn it into a relational algebra.\nFor this, DataFusion uses the SQL language description specified in the `sqlparser-rs library <https://github.com/sqlparser-rs/sqlparser-rs/>`_\nWe also include `SQL extensions specific to Dask-SQL <https://github.com/dask-contrib/dask-sql/blob/main/src/parser.rs/>`_. They specify custom language features, such as the ``CREATE MODEL`` statement.\n\n3. SQL is (maybe) optimized\n---------------------------\n\nOnce the SQL string is parsed into a :class:`Statement` enum, DataFusion can convert it into a relational algebra represented by a `LogicalPlan enum <https://docs.rs/datafusion-expr/latest/datafusion_expr/enum.LogicalPlan.html>`_\nand optimize it. 
As this is only implemented for syntax supported by DataFusion (and not for custom syntax such\nas :class:`SqlCreateModel`), this conversion and optimization is not triggered for all SQL statements (have a look\nat :func:`Context._get_ral`).\n\nThe logical plan is a tree structure and most enum variants (such as :class:`Projection` or :class:`Join`) can contain\nother instances as \"inputs\", creating a tree of different steps in the SQL statement (see below for an example).\n\nThe result is an optimized :class:`LogicalPlan`.\n\n4. Translation to Dask API calls\n--------------------------------\n\nEach step in the :class:`LogicalPlan` is converted into calls to Python functions using different Python \"converters\".\nFor each enum variant (such as :class:`Projection` and :class:`Join`), there exists a converter class in\nthe ``dask_sql.physical.rel`` folder, which is registered with the :class:`dask_sql.physical.rel.convert.RelConverter` class.\n\nTheir job is to use the information stored in the logical plan enum variants and turn it into calls to Python functions (see the example below for more information).\n\nAs many SQL statements contain calculations using literals and/or columns, these are split into their own functionality (``dask_sql.physical.rex``) following a similar plugin-based converter system.\nHave a look at the specific classes to understand how the conversion of a specific SQL language feature is implemented.\n\n5. Result\n---------\n\nThe result of each of the conversions is a :class:`dask.dataframe.DataFrame`, which is given to the user. In the case of the command line tool or the SQL server, it is evaluated immediately - otherwise it can be used for further calculations by the user.\n\nExample\n-------\n\nLet's walk through the steps above using the example SQL statement\n\n.. 
code-block:: sql\n\n    SELECT x + y FROM timeseries WHERE x > 0\n\nassuming the table \"timeseries\" is already registered.\nIf you want to follow along with the steps outlined below, start the command line tool in debug mode\n\n.. code-block:: bash\n\n    dask-sql --load-test-data --startup --log-level DEBUG\n\nand enter the SQL statement above.\n\nFirst, the SQL is parsed by DataFusion and (as it is not a custom statement) transformed into a tree of relational algebra objects.\n\n.. code-block:: none\n\n    Projection: #timeseries.x + #timeseries.y\n      Filter: #timeseries.x > Float64(0)\n        TableScan: timeseries projection=[x, y]\n\nThe tree output above means that the outer instance (:class:`Projection`) needs the output of the nested instance (:class:`Filter`) as its input, and so on.\n\nTherefore, the conversion to Python API calls is performed recursively (depth-first). First, the :class:`LogicalTableScan` is converted using the :class:`rel.logical.table_scan.LogicalTableScanPlugin` plugin. It will just get the correct :class:`dask.dataframe.DataFrame` from the dictionary of already registered tables of the context.\nNext, the :class:`LogicalFilter` (having the dataframe as input) is converted via the :class:`rel.logical.filter.LogicalFilterPlugin`.\nThe filter expression ``>($3, 0)`` is converted into ``df[\"x\"] > 0`` using a combination of REX plugins (have a look at the debug output to learn more) and applied to the dataframe.\nThe resulting dataframe is then passed to the converter :class:`rel.logical.project.LogicalProjectPlugin` for the :class:`LogicalProject`.\nThis will calculate the expression ``df[\"x\"] + df[\"y\"]`` (after having converted it via the :class:`RexCallPlugin` plugin) and return the final result to the user.\n\n.. 
code-block:: python\n\n    df_table_scan = context.tables[\"timeseries\"]\n    df_filter = df_table_scan[df_table_scan[\"x\"] > 0]\n    df_project = df_filter.assign(col=df_filter[\"x\"] + df_filter[\"y\"])\n    return df_project[[\"col\"]]\n"
  },
  {
    "path": "docs/source/index.rst",
    "content": "dask-sql\n========\n\n``dask-sql`` is a distributed SQL query engine in Python.\nIt allows you to query and transform your data using a mixture of\ncommon SQL operations and Python code and also scale up the calculation easily\nif you need it.\n\n* **Combine the power of Python and SQL**: load your data with Python, transform it with SQL, enhance it with Python and query it with SQL - or the other way round.\n  With ``dask-sql`` you can mix the well-known Python dataframe API of ``pandas`` and ``Dask`` with common SQL operations, to\n  process your data in exactly the way that is easiest for you.\n* **Infinite Scaling**: using the power of the great ``Dask`` ecosystem, your computations can scale as you need - from your laptop to your super cluster - without changing any line of SQL code. From k8s to cloud deployments, from batch systems to YARN - if ``Dask`` `supports it <https://docs.dask.org/en/latest/setup.html>`_, so will ``dask-sql``.\n* **Your data - your queries**: Use Python user-defined functions (UDFs) in SQL without any performance drawback and extend your SQL queries with the large number of Python libraries, e.g. machine learning, different complicated input formats, complex statistics.\n* **Easy to install and maintain**: ``dask-sql`` is just a pip/conda install away (or a docker run if you prefer).\n* **Use SQL from wherever you like**: ``dask-sql`` integrates with your jupyter notebook, your normal Python module or can be used as a standalone SQL server from any BI tool. 
It even integrates natively with `Apache Hue <https://gethue.com/>`_.\n* **GPU Support**: ``dask-sql`` has support for running SQL queries on CUDA-enabled GPUs by utilizing `RAPIDS <https://rapids.ai>`_ libraries like `cuDF <https://github.com/rapidsai/cudf>`_ , enabling accelerated compute for SQL.\n\n\nExample\n-------\n\nFor this example, we use some data loaded from disk and query it with a SQL command.\n``dask-sql`` accepts any pandas, cuDF, or dask dataframe as input and is able to read data directly from a variety of storage formats (CSV, Parquet, JSON) and file systems (S3, hdfs, gcs):\n\n.. tabs::\n\n   .. group-tab:: CPU\n\n      .. code-block:: python\n\n         import dask.datasets\n         from dask_sql import Context\n\n         # create a context to register tables\n         c = Context()\n\n         # create a table and register it in the context\n         df = dask.datasets.timeseries()\n         c.create_table(\"timeseries\", df)\n\n         # execute a SQL query; the result is a \"lazy\" Dask dataframe\n         result = c.sql(\"\"\"\n            SELECT\n               name, SUM(x) as \"sum\"\n            FROM\n               timeseries\n            GROUP BY\n               name\n         \"\"\")\n\n         # actually compute the query...\n         result.compute()\n\n         # ...or use it for another computation\n         result[\"sum\"].mean().compute()\n\n   .. group-tab:: GPU\n\n      .. 
code-block:: python\n\n         import dask.datasets\n         from dask_sql import Context\n\n         # create a context to register tables\n         c = Context()\n\n         # create a table and register it in the context\n         df = dask.datasets.timeseries()\n         c.create_table(\"timeseries\", df, gpu=True)\n\n         # execute a SQL query; the result is a \"lazy\" Dask dataframe\n         result = c.sql(\"\"\"\n            SELECT\n               name, SUM(x) as \"sum\"\n            FROM\n               timeseries\n            GROUP BY\n               name\n         \"\"\")\n\n         # actually compute the query...\n         result.compute()\n\n         # ...or use it for another computation\n         result[\"sum\"].mean().compute()\n\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Contents:\n\n   installation\n   quickstart\n   sql\n   data_input\n   custom\n   machine_learning\n   best_practices\n   api\n   server\n   cmd\n   fugue\n   how_does_it_work\n   configuration\n\n\n.. note::\n\n   ``dask-sql`` is currently under development and does not yet understand all SQL commands.\n   We are actively looking for feedback, improvements and contributors!\n"
  },
  {
    "path": "docs/source/installation.rst",
    "content": ".. _installation:\n\nInstallation\n============\n\n``dask-sql`` can be installed via ``conda`` (preferred) or ``pip`` - or in a development environment.\n\nYou can continue with the :ref:`quickstart` after the installation.\n\nWith ``conda``\n--------------\n\nCreate a new conda environment or use your already present environment:\n\n.. code-block:: bash\n\n    conda create -n dask-sql\n    conda activate dask-sql\n\nInstall the package from the ``conda-forge`` channel:\n\n.. code-block:: bash\n\n    conda install dask-sql -c conda-forge\n\nGPU support\n^^^^^^^^^^^\n\n- GPU support is currently tied to the `RAPIDS <https://rapids.ai/>`_  libraries.\n- It generally requires the latest `cuDF/Dask-cuDF <https://docs.rapids.ai/api/cudf/nightly/user_guide/10min.html>`_ nightlies.\n\nCreate a new conda environment or use an existing one to install RAPIDS with the chosen methods and packages.\nMore details can be found on the `RAPIDS Getting Started <https://rapids.ai/start.html>`_ page, but as an example:\n\n.. code-block:: bash\n\n    conda create --name rapids-env -c rapidsai-nightly -c nvidia -c conda-forge \\\n        cudf=22.10 dask-cudf=22.10 ucx-py ucx-proc=*=gpu python=3.9 cudatoolkit=11.8\n    conda activate rapids-env\n\nNote that using UCX is mainly necessary if you have an Infiniband or NVLink enabled system.\nRefer to the `UCX-Py docs <https://ucx-py.readthedocs.io/en/latest/>`_ for more information.\n\nInstall the stable package from the ``conda-forge`` channel:\n\n.. code-block:: bash\n\n    conda install -c conda-forge dask-sql\n\nOr the latest nightly from the ``dask`` channel (currently only available for Linux-based operating systems):\n\n.. code-block:: bash\n\n    conda install -c dask/label/dev dask-sql\n\n\nWith ``pip``\n------------\n\n.. 
code-block:: bash\n\n    pip install dask-sql\n\nFor development\n---------------\n\nIf you want to have the newest (unreleased) ``dask-sql`` version or if you plan to do development on ``dask-sql``, you can also install the package from source.\n\n.. code-block:: bash\n\n    git clone https://github.com/dask-contrib/dask-sql.git\n\nCreate a new conda environment and install the development environment:\n\n.. code-block:: bash\n\n    conda env create -f continuous_integration/environment-3.9.yaml\n\nIt is not recommended to use ``pip`` instead of ``conda``.\n\nAfter that, you can install the package in development mode\n\n.. code-block:: bash\n\n    pip install -e \".[dev]\"\n\nTo compile the Rust code (after changes), the above command must be rerun.\nYou can run the tests (after installation) with\n\n.. code-block:: bash\n\n    pytest tests\n\nGPU-specific tests require additional dependencies specified in ``continuous_integration/gpuci/environment.yaml``:\n\n.. code-block:: bash\n\n    conda env create -n dask-sql-gpuci -f continuous_integration/gpuci/environment.yaml\n\nGPU-specific tests can be run with\n\n.. code-block:: bash\n\n    pytest tests -m gpu --rungpu\n\nThis repository uses pre-commit hooks. To install them, run\n\n.. code-block:: bash\n\n    pre-commit install\n"
  },
  {
    "path": "docs/source/machine_learning.rst",
    "content": ".. _machine_learning:\n\nMachine Learning\n================\n\n.. note::\n    Machine Learning support is experimental in ``dask-sql``.\n    We encourage you to try it out and report any issues on our\n    `issue tracker <https://github.com/dask-contrib/dask-sql/issues>`_.\n\nBoth training and prediction with Machine Learning methods play a crucial role in\nmany data analytics applications. ``dask-sql`` supports Machine Learning\napplications in different ways, depending on how much you would like to do in Python or SQL.\n\nPlease also see :ref:`ml` for more information on the SQL statements used on this page.\n\n1. Data Preparation in SQL, Training and Prediction in Python\n-------------------------------------------------------------\n\nIf you are familiar with Python and the ML ecosystem in Python, this one is probably\nthe simplest possibility. You can use the :func:`~dask_sql.Context.sql` call as described\nbefore to extract the data for your training or ML prediction.\nThe result will be a Dask dataframe, which you can either directly feed into your model\nor convert to a pandas dataframe with ``.compute()`` first.\n\nThis gives you full control over the training process and the simplicity of\nusing SQL for data manipulation. You can use this method in your Python scripts\nor Jupyter Notebooks, but not from the :ref:`server` or :ref:`cmd`.\n\n2. Training in Python, Prediction in SQL\n----------------------------------------\n\nIn many companies/teams, it is typical that some team members are responsible for\ncreating/training an ML model, and others use it to predict unseen data.\nIt would be possible to create a custom function (see :ref:`custom`) to load and use the model,\nwhich then can be used in ``SELECT`` queries.\nHowever, for convenience, ``dask-sql`` introduces a SQL keyword to do this work for you\nautomatically. 
The syntax is similar to the `BigQuery Predict Syntax <https://cloud.google.com/bigquery-ml/docs/reference/standard-sql/bigqueryml-syntax-predict>`_.\n\n.. code-block:: python\n\n    c.sql(\"\"\"\n    SELECT * FROM PREDICT ( MODEL my_model,\n        SELECT x, y, z FROM data\n    )\n    \"\"\")\n\nThis call will first collect the data from the inner ``SELECT`` call (which can be any valid\n``SELECT`` call, including ``JOIN``, ``WHERE``, ``GROUP BY``, custom tables and views etc.)\nand will then apply the model with the name \"my_model\" for prediction.\nThe model needs to be registered at the context beforehand with :func:`~dask_sql.Context.register_model`.\n\n.. code-block:: python\n\n    c.register_model(\"my_model\", model)\n\nThe model registered here can be any valid Python object that follows the scikit-learn\ninterface, i.e. has a ``predict()`` function.\nPlease note that the input will not be a pandas dataframe, but a Dask dataframe.\nSee :ref:`ml` for more information.\n\n3. Training and Prediction in SQL\n---------------------------------\n\nThis method, in contrast to the other two possibilities, works completely from SQL,\nwhich allows you to also call it e.g. from your BI tool.\nIn addition to the ``PREDICT`` keyword mentioned above, ``dask-sql`` also has a way to\ncreate and train a model from SQL:\n\n.. code-block:: sql\n\n    CREATE MODEL my_model WITH (\n        model_class = 'LogisticRegression',\n        wrap_predict = True,\n        target_column = 'target'\n    ) AS (\n        SELECT x, y, x*y > 0 as target\n        FROM timeseries\n        LIMIT 100\n    )\n\nThis call will create a new instance of ``sklearn.linear_model.LogisticRegression`` or ``cuml.linear_model.LogisticRegression``\n(the full path is inferred by Dask-SQL depending on whether you are using a CPU or GPU DataFrame)\nand train it with the data collected from the ``SELECT`` call (again, every valid ``SELECT``\nquery can be given). 
The model can then be used in subsequent calls to ``PREDICT``\nusing the given name.\nWe explicitly set ``wrap_predict`` = ``True`` here to parallelize the post-fit prediction task of non-distributed models (sklearn, cuML, etc.) across workers,\nalthough in this case ``wrap_predict`` would have already defaulted to ``True`` for the sklearn model.\n\nHave a look into :ref:`ml` for more information.\n\n4. Check Model parameters - Model meta data\n-------------------------------------------\nAfter the model has been trained, you can inspect and get model details by using the\nfollowing SQL statements:\n\n.. code-block:: sql\n\n    -- show the list of models which are trained and stored in the context.\n    SHOW MODELS\n\n    -- To get the hyperparameters of the trained MODEL, use\n    -- DESCRIBE MODEL <model_name>.\n    DESCRIBE MODEL my_model\n\n5. Hyperparameter Tuning\n-------------------------\nWant to increase the performance of your model by tuning the\nparameters? Use hyperparameter tuning directly\nin SQL with the syntax below, and choose different tuners\nbased on memory and compute constraints.\n\n..\n    TODO - add a GPU section to these examples once we have working CREATE EXPERIMENT tests for GPU\n\n.. code-block:: sql\n\n    CREATE EXPERIMENT my_exp WITH (\n        model_class = 'GradientBoostingClassifier',\n        experiment_class = 'GridSearchCV',\n        tune_parameters = (n_estimators = ARRAY [16, 32, 2],\n                           learning_rate = ARRAY [0.1, 0.01, 0.001],\n                           max_depth = ARRAY [3, 4, 5, 10]\n                          ),\n        experiment_kwargs = (n_jobs = -1),\n        target_column = 'target'\n    ) AS (\n        SELECT x, y, x*y > 0 AS target\n        FROM timeseries\n        LIMIT 100\n    )\n\nIn this case, we set ``n_jobs`` = ``-1`` to ensure that all jobs run in parallel.\n\n5.1 AutoML in SQL\n-----------------\nWant to try different models with different parameters in SQL? 
Now you can\nstart AutoML experiments with the help of the ``tpot`` framework, which trains\nand evaluates a number of different sklearn-compatible models and uses Dask for\ndistributing the work across the Dask cluster.\nUse the SQL syntax below for AutoML; for more details, refer to the\n`tpot automl framework <https://epistasislab.github.io/tpot/>`_.\n\n\n.. code-block:: sql\n\n    CREATE EXPERIMENT my_exp WITH (\n            automl_class = 'tpot.TPOTClassifier',\n            automl_kwargs = (population_size = 2 ,\n            generations=2,\n            cv=2,\n            n_jobs=-1,\n            use_dask=True,\n            max_eval_time_mins=1),\n            target_column = 'target'\n            ) AS (\n                SELECT x, y, x*y > 0 AS target\n                FROM timeseries\n                LIMIT 100\n            )\n\nAfter the experiment has completed, both hyperparameter tuning and\nAutoML experiments store the best model of the experiment in the SQL context under\nthe same name as the experiment, which can be used for prediction.\n\n6. Export Trained Model\n------------------------\nOnce your model has been trained and performs well on your validation dataset,\nyou can export the model into a file with one of the supported model serialization\nformats like Pickle, Joblib, MLflow (framework-agnostic serialization format), etc.\n\nCurrently, Dask-SQL supports the Pickle, Joblib and MLflow formats for exporting the\ntrained model, which can then be deployed as microservices, etc.\n\nBefore training and exporting the models from different frameworks like\nLightGBM or CatBoost, please ensure the relevant packages are installed in the\nDask-SQL environment, otherwise it will raise an exception on import. If you\nare using MLflow, ensure MLflow is installed. 
Keep in mind that Dask-SQL supports\nonly sklearn-compatible models (i.e. fit-predict style models) so far, so instead of using\n``xgb.core.Booster``, consider using ``xgboost.XGBClassifier`` since the latter is sklearn-compatible\nand used by Dask-SQL for training, predicting, and exporting the model\nthrough the standard sklearn interface.\n\n\n..\n    TODO - add a GPU section to these examples once we have working EXPORT MODEL tests for GPU\n\n.. code-block:: sql\n\n    -- for pickle model serialization\n    EXPORT MODEL my_model WITH (\n        format ='pickle',\n        location = 'model.pkl'\n    )\n\n    -- for joblib model serialization\n    EXPORT MODEL my_model WITH (\n        format ='joblib',\n        location = 'model.pkl'\n    )\n\n    -- for mlflow model serialization\n    EXPORT MODEL my_model WITH (\n        format ='mlflow',\n        location = 'mlflow_dir'\n    )\n\n    -- Note: you can pass additional key-value pairs\n    -- (parameters), which will be delegated to the respective\n    -- export functions\n\n\nExample\n~~~~~~~\n\nThe following SQL-only code gives an example of how the commands can work together.\nWe assume that you have created/registered a table \"my_data\" with the numerical columns ``x`` and ``y``\nand the boolean target ``label``.\n\n..\n    TODO - add a GPU section to these examples once we have working CREATE EXPERIMENT tests for GPU\n\n.. 
code-block:: sql\n\n    -- First, we create a new feature z out of x and y.\n    -- For convenience, we store it in another table\n    CREATE OR REPLACE TABLE transformed_data AS (\n        SELECT x, y, x + y AS z, label\n        FROM my_data\n    )\n\n    -- We split the data into a training set\n    -- by using the first 15 items.\n    -- Please note that this is just for a very quick-and-dirty\n    -- example - you would probably want to do something\n    -- more advanced here, maybe with TABLESAMPLE\n    CREATE OR REPLACE TABLE training_data AS (\n        SELECT * FROM transformed_data\n        LIMIT 15\n    )\n\n    -- Quickly check the data\n    SELECT * FROM training_data\n\n    -- We can now train a model from the sklearn package.\n    CREATE OR REPLACE MODEL my_model WITH (\n        model_class = 'sklearn.ensemble.GradientBoostingClassifier',\n        wrap_predict = True,\n        target_column = 'label'\n    ) AS (\n        SELECT * FROM training_data\n    )\n\n    -- Now apply the trained model on all the data\n    -- and compare.\n    SELECT\n        *, (CASE WHEN target = label THEN True ELSE False END) AS correct\n    FROM PREDICT(MODEL my_model,\n        SELECT * FROM transformed_data\n    )\n\n    -- list models\n    SHOW MODELS\n    -- check parameters of the model\n    DESCRIBE MODEL my_model\n\n    -- experiment to tune different hyperparameters\n    CREATE EXPERIMENT my_exp WITH(\n    model_class = 'sklearn.ensemble.GradientBoostingClassifier',\n    experiment_class = 'sklearn.model_selection.GridSearchCV',\n    tune_parameters = (n_estimators = ARRAY [16, 32, 2],\n                    learning_rate = ARRAY [0.1,0.01,0.001],\n                   max_depth = ARRAY [3,4,5,10]\n                   ),\n    experiment_kwargs = (n_jobs = -1),\n    target_column = 'label'\n    ) AS (\n        SELECT * FROM training_data\n    )\n\n\n    -- creates experiment with automl framework\n    CREATE EXPERIMENT my_exp WITH (\n            automl_class = 
'tpot.TPOTRegressor',\n            automl_kwargs = (population_size = 2 ,\n            generations=2,\n            cv=2,\n            n_jobs=-1,\n            use_dask=True,\n            max_eval_time_mins=1),\n            target_column = 'z'\n            ) AS (\n                SELECT * FROM training_data\n            )\n\n    -- checks the parameter of automl model\n    DESCRIBE MODEL automl_TPOTRegressor\n\n    -- export model\n    EXPORT MODEL my_model WITH (\n        format ='pickle',\n        location = 'model.pkl'\n    )\n"
  },
  {
    "path": "docs/source/quickstart.rst",
    "content": ".. _quickstart:\n\nQuickstart\n==========\n\nAfter :ref:`installation`, you can start querying your data using SQL.\n\nRun the following code in an interactive Python session, a Python script or a Jupyter Notebook.\n\n0. Cluster Setup\n----------------\n\nIf you just want to try out ``dask-sql`` quickly, this step can be skipped.\nHowever, the real magic of ``dask`` (and ``dask-sql``) comes from the ability to scale the computations over multiple cores and/or machines.\nFor local development and testing, a Distributed ``LocalCluster`` (or, if using GPUs, a `Dask-CUDA <https://docs.rapids.ai/api/dask-cuda/nightly/index.html>`_ ``LocalCUDACluster``) can be deployed and a client connected to it like so:\n\n.. tabs::\n\n    .. group-tab:: CPU\n\n        .. code-block:: python\n\n            from distributed import Client, LocalCluster\n\n            cluster = LocalCluster()\n            client = Client(cluster)\n\n    .. group-tab:: GPU\n\n        .. code-block:: python\n\n            from dask_cuda import LocalCUDACluster\n            from distributed import Client\n\n            cluster = LocalCUDACluster()\n            client = Client(cluster)\n\nThere are several options for deploying clusters depending on the platform being used and the resources available; see `Dask - Deploying Clusters <https://docs.dask.org/en/latest/deploying.html>`_ for more information.\n\n1. Data Loading\n---------------\n\nBefore querying the data, you need to create a ``dask`` `data frame <https://docs.dask.org/en/latest/dataframe.html>`_ containing the data.\n``dask`` understands many different `input formats <https://docs.dask.org/en/latest/dataframe-create.html>`_ and sources.\nIn this example, we do not read in external data, but use test data in the form of random event time series:\n\n.. code-block:: python\n\n    import dask.datasets\n\n    df = dask.datasets.timeseries()\n\nRead more on the data input part in :ref:`data_input`.\n\n2. 
Data Registration\n--------------------\n\nIf we want to work with the data in SQL, we need to give the data frame a unique name.\nWe do this by registering the data in an instance of a :class:`~dask_sql.Context`:\n\n.. tabs::\n\n    .. group-tab:: CPU\n\n        .. code-block:: python\n\n            from dask_sql import Context\n\n            c = Context()\n            c.create_table(\"timeseries\", df)\n\n    .. group-tab:: GPU\n\n        .. code-block:: python\n\n            from dask_sql import Context\n\n            c = Context()\n            c.create_table(\"timeseries\", df, gpu=True)\n\nFrom now on, the data is accessible as the ``timeseries`` table of this context.\nIt is possible to register multiple data frames in the same context.\n\n.. hint::\n    If you plan to query the same data multiple times,\n    it might make sense to persist the data before:\n\n    .. tabs::\n\n        .. group-tab:: CPU\n\n            .. code-block:: python\n\n                c.create_table(\"timeseries\", df, persist=True)\n\n        .. group-tab:: GPU\n\n            .. code-block:: python\n\n                c.create_table(\"timeseries\", df, persist=True, gpu=True)\n\n3. Run your queries\n-------------------\n\nNow you can go ahead and query the data with normal SQL!\n\n.. code-block:: python\n\n    result = c.sql(\"\"\"\n        SELECT\n            name, SUM(x) AS \"sum\"\n        FROM timeseries\n        WHERE x > 0.5\n        GROUP BY name\n    \"\"\")\n    result.compute()\n\n``dask-sql`` understands a large fraction of SQL commands, but there are still some missing.\nHave a look into the :ref:`sql` description for more information.\n\nIf you are using ``dask-sql`` from a Jupyter notebook, you might be interested in the ``sql`` magic function:\n\n.. code-block:: python\n\n    c.ipython_magic()\n\n    %%sql\n    SELECT\n        name, SUM(x) AS \"sum\"\n    FROM timeseries\n    WHERE x > 0.5\n    GROUP BY name\n\n.. 
note::\n    If you have found an SQL feature, which is currently not supported by ``dask-sql``,\n    please raise an issue on our `issue tracker <https://github.com/dask-contrib/dask-sql/issues>`_.\n"
  },
  {
    "path": "docs/source/server.rst",
    "content": ".. _server:\n\nSQL Server\n==========\n\n``dask-sql`` comes with a small test implementation for a SQL server.\nInstead of rebuilding a full ODBC driver, we re-use the `presto wire protocol <https://github.com/prestodb/presto/wiki/HTTP-Protocol>`_.\n\n.. note::\n\n    It is - so far - only a start of the development and missing important concepts, such as\n    authentication.\n\nYou can test the sql presto server by running (after installation)\n\n.. code-block:: bash\n\n    dask-sql-server\n\nor by running these lines of code\n\n.. code-block:: python\n\n    from dask_sql import run_server\n\n    run_server()\n\nor directly with a created context\n\n.. code-block:: python\n\n    c.run_server()\n\nor by using the created docker image\n\n.. code-block:: bash\n\n    docker run --rm -it -p 8080:8080 nbraun/dask-sql\n\nThis will spin up a server on port 8080 (by default).\nThe port and bind interfaces can be controlled with the ``--port`` and ``--host`` command line arguments (or options to :func:`~dask_sql.run_server`).\n\nThe running server looks similar to a normal presto database to any presto client and can therefore be used\nwith any library, e.g. the `presto CLI client <https://prestosql.io/docs/current/installation/cli.html>`_ or\n``sqlalchemy`` via the `PyHive <https://github.com/dropbox/PyHive#sqlalchemy>`_ package:\n\n.. code-block:: bash\n\n    presto --server localhost:8080\n\nNow you can fire simple SQL queries (as no data is loaded by default):\n\n.. code-block::\n\n    => SELECT 1 + 1;\n     EXPR$0\n    --------\n        2\n    (1 row)\n\nOr via ``sqlalchemy`` (after having installed ``PyHive``):\n\n.. 
code-block:: python\n\n    from sqlalchemy.engine import create_engine\n    engine = create_engine('presto://localhost:8080/')\n\n    import pandas as pd\n    pd.read_sql_query(\"SELECT 1 + 1\", con=engine)\n\nOf course, it is also possible to call the usual ``CREATE TABLE``\ncommands.\n\nPreregister your own data sources\n---------------------------------\n\nThe Python function :func:`~dask_sql.run_server` accepts an already created :class:`~dask_sql.Context`.\nThis means you can preload your data sources and register them with a context before starting your server.\nThis way, your server will already have data to query:\n\n.. code-block:: python\n\n    from dask_sql import Context\n    c = Context()\n    c.create_table(...)\n\n    # Then spin up the ``dask-sql`` server\n    from dask_sql import run_server\n    run_server(context=c)\n\n\nRun it in your own ``dask`` cluster\n-----------------------------------\n\nThe SQL server implementation in ``dask-sql`` allows you to run a SQL server as a service connected to your ``dask`` cluster.\nThis enables your users to run SQL commands leveraging the full power of your ``dask`` cluster without the need to write Python code\nand also allows the use of non-Python tools (such as BI tools) as long as they can speak the presto protocol.\n\nTo run a standalone SQL server in your ``dask`` cluster, follow these three steps:\n\n1. Create a startup script to connect ``dask-sql`` to your cluster.\n   There exist many different ways to connect to a ``dask`` cluster (e.g. direct access to the scheduler,\n   dask gateway, ...). Choose the one suitable for your cluster and create a small startup script:\n\n   .. code-block:: python\n\n        # Connect to your cluster here, e.g.\n        from dask.distributed import Client\n        client = Client(scheduler_address)\n\n        ...\n\n        # Then spin up the ``dask-sql`` server\n        from dask_sql import run_server\n        run_server(client=client)\n\n2. 
Deploy this script to your cluster as a service. How you do this depends on your cluster infrastructure (kubernetes, mesos, openshift, ...).\n   For example, you could create a docker image with a dockerfile similar to this:\n\n   .. code-block:: dockerfile\n\n        FROM nbraun/dask-sql\n\n        COPY continuous_integration/docker/startup_script.py /opt/dask_sql/startup_script.py\n\n        ENTRYPOINT [ \"/opt/conda/bin/python\", \"/opt/dask_sql/startup_script.py\" ]\n\n3. After your service is deployed, you can use it in your applications as a \"normal\" presto database.\n\nThe ``dask-sql`` SQL server was successfully tested with `Apache Hue <https://gethue.com/>`_, `Apache Superset <https://superset.apache.org/>`_\nand `Metabase <https://www.metabase.com/>`_.\n\n\nRunning from a jupyter notebook\n-------------------------------\n\nIf you quickly want to bridge the gap between your jupyter notebook and a BI tool,\nyou can run a temporary SQL server from your jupyter notebook.\n\n.. code-block:: python\n\n    # Create a Context and work with it\n    from dask_sql import Context\n    c = Context()\n\n    ...\n\n    # Later create a temporary server\n    c.run_server(blocking=False)\n\n    # Continue working\n\nThis allows you to access the same context with all its registered tables\nboth in the jupyter notebook as well as by connecting to the SQL server\nstarted on port 8080 (e.g. with your BI tool).\n\nOnce you are done with the SQL server, you can close it with\n\n.. code-block:: python\n\n    c.stop_server()\n\nPlease note that this feature should not be used for production SQL servers,\nbut just for quick analyses via an external application.\n"
  },
  {
    "path": "docs/source/sql/creation.rst",
    "content": ".. _creation:\n\nTable Creation\n==============\n\nAs described in :ref:`quickstart`, it is possible to register an already\ncreated dask dataframe with a call to ``c.create_table``.\nHowever, it is also possible to load data directly from disk (or s3, hdfs, URL, hive, ...)\nand register it as a table in ``dask_sql``.\nBehind the scenes, a call to one of the ``read_<format>`` of the ``dask.dataframe``\nwill be executed.\nAdditionally, queries can be materialized into new tables for caching or faster access.\n\n.. raw:: html\n\n    <div class=\"highlight-sql notranslate\">\n    <div class=\"highlight\"><pre>\n    <span class=\"k\">CREATE</span> [ <span class=\"k\">OR REPLACE</span> ] <span class=\"k\">TABLE</span> [ <span class=\"k\">IF NOT EXISTS</span> ] <span class=\"ss\">&lt;table-name></span>\n        <span class=\"k\">WITH</span> ( <span class=\"ss\">&lt;key&gt;</span> = <span class=\"ss\">&lt;value&gt;</span> [ , ... ] )\n    <span class=\"k\">CREATE</span> [ <span class=\"k\">OR REPLACE</span> ] <span class=\"k\">TABLE</span> [ <span class=\"k\">IF NOT EXISTS</span> ] <span class=\"ss\">&lt;table-name></span>\n        <span class=\"k\">AS</span> ( <span class=\"k\">SELECT</span> ... )\n    <span class=\"k\">CREATE</span> [ <span class=\"k\">OR REPLACE</span> ] <span class=\"k\">VIEW</span> [ <span class=\"k\">IF NOT EXISTS</span> ] <span class=\"ss\">&lt;table-name></span>\n        <span class=\"k\">AS</span> ( <span class=\"k\">SELECT</span> ... )\n    <span class=\"k\">DROP TABLE</span> | <span class=\"k\">VIEW</span> [ <span class=\"k\">IF EXISTS</span> ] <span class=\"ss\">&lt;table-name></span>\n    </pre></div>\n    </div>\n\nSee :ref:`sql` for information on how to reference tables correctly.\nPlease note, that there can only ever exist a single view or table with the same name.\n\n.. 
note::\n\n    As there is only a single schema \"schema\" in ``dask-sql``,\n    table names should not include a separator \".\" in ``CREATE`` calls.\n\nBy default, if a table with the same name does already exist, ``dask-sql`` will raise an exception\n(and in turn will raise an exception if you try to delete a table which is not present).\nWith the flags ``IF [NOT] EXISTS`` and ``OR REPLACE``, this behavior can be controlled:\n\n* ``CREATE OR REPLACE TABLE | VIEW`` will override an already present table/view with the same name without raising an exception.\n* ``CREATE TABLE IF NOT EXISTS`` will not create the table/view if it already exists (and will also not raise an exception).\n* ``DROP TABLE | VIEW IF EXISTS`` will only drop the table/view if it exists and will not do anything otherwise.\n\n``CREATE TABLE WITH``\n---------------------\n\nThis will create and register a new table \"df\" with the data under the specified location\nand format.\nFor information on how to specify key-value arguments properly, see :ref:`sql`.\nWith the ``persist`` parameter, it can be controlled if the data should be cached\nor re-read for every SQL query.\nThe additional parameters are passed to the particular data loading functions.\nIf you omit the format argument, it will be deduced from the file name extension.\nMore ways to load data can be found in :ref:`data_input`.\n\nExample:\n\n.. 
raw:: html\n\n    <div class=\"highlight-sql notranslate\"><div class=\"highlight\"><pre><span></span><span class=\"k\">CREATE</span> <span class=\"k\">TABLE</span> <span class=\"n\">df</span> <span class=\"k\">WITH</span> <span class=\"p\">(</span>\n        <span class=\"n\">location</span> <span class=\"o\">=</span> <span class=\"ss\">\"/some/file/path\"</span><span class=\"p\">,</span>\n        <span class=\"n\">format</span> <span class=\"o\">=</span> <span class=\"ss\">\"csv/parquet/json/...\"</span><span class=\"p\">,</span>\n        <span class=\"n\">persist</span> <span class=\"o\">=</span> <span class=\"k\">True</span><span class=\"p\">,</span>\n        <span class=\"n\">additional_parameter</span> <span class=\"o\">=</span> <span class=\"n\">value</span><span class=\"p\">,</span>\n        <span class=\"p\">...</span>\n    <span class=\"p\">)</span>\n    </pre></div>\n    </div>\n\n``CREATE TABLE AS``\n-------------------\n\nUsing a similar syntax, it is also possible to create a (materialized) view of a (maybe complicated) SQL query.\nWith the command, you give the result of the ``SELECT`` query a name, that you can use\nin subsequent calls.\nThe ``SELECT`` can also contain a call to ``PREDICT``, see :ref:`ml`.\n\nExample:\n\n.. code-block:: sql\n\n    CREATE TABLE my_table AS (\n        SELECT\n            a, b, SUM(c)\n        FROM data\n        GROUP BY a, b\n        ...\n    )\n\n    SELECT * FROM my_table\n\n``CREATE VIEW AS``\n------------------\n\nInstead of using ``CREATE TABLE`` it is also possible to use ``CREATE VIEW``.\nThe result is very similar, the only difference is, *when* the result will be computed: a view is recomputed on every usage,\nwhereas a table is only calculated once on creation (also known as a materialized view).\nThis means, if you e.g. 
read data from a remote file and the file changes, a query containing a view will\nbe updated whereas a query with a table will stay as it is.\nTo update a table, you need to recreate it.\n\n.. hint::\n\n    Use views to simplify complicated queries (like a \"shortcut\") and tables for caching.\n\n.. note::\n\n    A view is only updated if your primary data source (the files you are reading in)\n    is not persisted during reading.\n\nExample:\n\n.. code-block:: sql\n\n    CREATE VIEW my_table AS (\n        SELECT\n            a, b, SUM(c)\n        FROM data\n        GROUP BY a, b\n        ...\n    )\n\n    SELECT * FROM my_table\n\n``DROP TABLE | VIEW``\n---------------------\n\nRemove a table or view with the given name.\nPlease note again that views and tables are treated equally, so ``DROP TABLE``\nwill also delete the view with the given name and vice versa.\n"
  },
  {
    "path": "docs/source/sql/describe.rst",
    "content": "Metadata Information\n====================\n\nWith these operations, it is possible to get information on the currently registered tables\nand their columns.\nThe output format is mostly compatible with the presto format.\n\n.. raw:: html\n\n    <div class=\"highlight\"><pre>\n    <span class=\"k\">SHOW SCHEMAS</span>\n    <span class=\"k\">SHOW TABLES FROM</span> <span class=\"ss\">&lt;schema-name&gt;</span>\n    <span class=\"k\">SHOW COLUMNS FROM</span> <span class=\"ss\">&lt;table-name></span>\n    <span class=\"k\">DESCRIBE</span> <span class=\"ss\">&lt;table-name></span>\n    <span class=\"k\">ANALYZE TABLE</span> <span class=\"ss\">&lt;table-name&gt;</span> <span class=\"k\">COMPUTE STATISTICS</span>\n        [ <span class=\"k\">FOR ALL COLUMNS</span> | <span class=\"k\">FOR COLUMNS</span> <span class=\"ss\">&lt;column&gt;</span>, [ ,... ] ]\n    </pre></div>\n\nSee :ref:`sql` for information on how to reference schemas and tables correctly.\n\n``SHOW SCHEMAS``\n----------------\n\nShow the schemas registered in ``dask-sql``.\nOnly included for compatibility reasons.\nThere is always just a one called \"schema\", where all the data is located and an additional schema, called \"information_schema\",\nwhich is needed by some BI tools (which is empty).\n\nExample:\n\n.. raw:: html\n\n    <div class=\"highlight\"><pre>\n    <span class=\"k\">SHOW SCHEMAS</span>\n    </pre></div>\n\nResult:\n\n+------------------------+\n| Schema                 |\n+========================+\n| schema                 |\n+------------------------+\n| information_schema     |\n+------------------------+\n\n``SHOW TABLES``\n---------------\n\nShow the registered tables in a given schema.\n\nExample:\n\n.. 
raw:: html\n\n    <div class=\"highlight\"><pre>\n    <span class=\"k\">SHOW TABLES FROM</span> <span class=\"ss\">\"schema\"</span>\n    </pre></div>\n\nResult:\n\n+------------+\n| Table      |\n+============+\n| timeseries |\n+------------+\n\n``SHOW COLUMNS`` and ``DESCRIBE``\n---------------------------------\n\nShow column information on a specific table.\n\nExample:\n\n.. raw:: html\n\n    <div class=\"highlight\"><pre>\n    <span class=\"k\">SHOW COLUMNS FROM</span> <span class=\"ss\">\"timeseries\"</span>\n    </pre></div>\n\nResult:\n\n+--------+---------+---------------+\n| Column |    Type | Extra Comment |\n+========+=========+===============+\n|     id |  bigint |               |\n+--------+---------+---------------+\n|   name | varchar |               |\n+--------+---------+---------------+\n|      x |  double |               |\n+--------+---------+---------------+\n|      y |  double |               |\n+--------+---------+---------------+\n\nThe column \"Extra Comment\" is shown for compatibility with Presto.\n\n\n``ANALYZE TABLE``\n-----------------\n\nCalculate statistics on a given table (and the given columns or all columns)\nand return them as a query result.\nPlease note that this process can be time-consuming on large tables.\nEven though this statement is very similar to the ``ANALYZE TABLE`` statement in e.g. `Apache Spark <https://spark.apache.org/docs/3.0.0/sql-ref-syntax-aux-analyze-table.html>`_, it does not optimize subsequent queries (as its counterpart in Spark does).\n\nExample:\n\n.. 
raw:: html\n\n    <div class=\"highlight\"><pre>\n    <span class=\"k\">ANALYZE TABLE</span> <span class=\"ss\">\"timeseries\"</span> <span class=\"k\">COMPUTE STATISTICS</span> <span class=\"k\">FOR COLUMNS</span> <span class=\"ss\">x</span>, <span class=\"ss\">y</span>\n    </pre></div>\n\nResult:\n\n+-----------+-----------+-----------+\n|           |         x |         y |\n+===========+===========+===========+\n| count     |        30 |        30 |\n+-----------+-----------+-----------+\n| mean      |  0.140374 | -0.107481 |\n+-----------+-----------+-----------+\n| std       |  0.568248 |  0.573106 |\n+-----------+-----------+-----------+\n| min       | -0.795112 | -0.966043 |\n+-----------+-----------+-----------+\n| 25%       | -0.379635 | -0.561234 |\n+-----------+-----------+-----------+\n| 50%       | 0.0104101 | -0.237795 |\n+-----------+-----------+-----------+\n| 75%       |   0.70208 |  0.263459 |\n+-----------+-----------+-----------+\n| max       |  0.990747 |  0.947069 |\n+-----------+-----------+-----------+\n| data_type |    double |    double |\n+-----------+-----------+-----------+\n| col_name  |         x |         y |\n+-----------+-----------+-----------+\n"
  },
  {
    "path": "docs/source/sql/ml.rst",
    "content": ".. _ml:\n\nMachine Learning in SQL\n=======================\n\n.. note::\n    Machine Learning support is experimental in ``dask-sql``.\n    We encourage you to try it out and report any issues on our\n    `issue tracker <https://github.com/dask-contrib/dask-sql/issues>`_.\n\nAs all SQL statements in ``dask-sql`` are eventually converted to Python calls, it is very simple to include\nany custom Python function and library, e.g. Machine Learning libraries. Although it would be possible to\nregister custom functions (see :ref:`custom`) for this and use them, it is much more convenient if this functionality\nis already included in the core SQL language.\nThese three statements help in training and using models. Every :class:`~dask_sql.Context` has a registry for models, which\ncan be used for training or prediction.\nFor a full example, see :ref:`machine_learning`.\n\n.. raw:: html\n\n    <div class=\"highlight-sql notranslate\">\n    <div class=\"highlight\"><pre>\n    <span class=\"k\">CREATE</span> [ <span class=\"k\">OR REPLACE</span> ] <span class=\"k\">MODEL</span> [ <span class=\"k\">IF NOT EXISTS</span> ] <span class=\"ss\">&lt;model-name></span>\n        <span class=\"k\">WITH</span> ( <span class=\"ss\">&lt;key&gt;</span> = <span class=\"ss\">&lt;value&gt;</span> [ , ... ] ) <span class=\"k\">AS</span> ( <span class=\"k\">SELECT</span> ... )\n    <span class=\"k\">DROP MODEL</span> [ <span class=\"k\">IF EXISTS</span> ] <span class=\"ss\">&lt;model-name></span>\n    <span class=\"k\">SELECT</span> <span class=\"ss\">&lt;expression&gt;</span> <span class=\"k\">FROM PREDICT</span> (<span class=\"k\">MODEL</span> <span class=\"ss\">&lt;model-name></span>, <span class=\"k\">SELECT</span> ... 
)\n    </pre></div>\n    </div>\n\n``IF [ NOT ] EXISTS`` and ``CREATE OR REPLACE`` behave similarly to the analogous flags in ``CREATE TABLE``.\nSee :ref:`creation` for more information.\n\n``CREATE MODEL``\n----------------\n\nCreate and train a model on the data from the given ``SELECT`` query\nand register it at the context.\n\nThe select query is a normal ``SELECT`` query (following the same syntax as described in :ref:`select`)\nor even a call to ``PREDICT`` (which typically does not make sense, however) and its\nresult is used as the training data.\n\nThe key-value parameters control how and which model is trained:\n\n    * ``model_class``:\n      This argument needs to be present.\n      It is the class name or full Python module path to the class of the model to train.\n      Any sklearn, cuML, XGBoost, or LightGBM classes can be inferred\n      without the full path. In this case, models trained on cuDF dataframes\n      are automatically mapped to cuML classes, and sklearn models otherwise.\n      We map to cuML-Dask based models when possible and single-GPU cuML models otherwise.\n      Any model class with an sklearn interface is valid, but might or\n      might not work well with Dask dataframes.\n      You might need to install necessary packages to use\n      the models.\n    * ``target_column``:\n      Which column from the data to use as target.\n      If not empty, it is removed automatically from\n      the training data. Defaults to an empty string, in which\n      case no target is fed to the model training (e.g. for\n      unsupervised algorithms). 
This means, you typically\n      want to set this parameter.\n    * ``wrap_predict``:\n      Boolean flag, whether to wrap the selected\n      model with a :class:`dask_sql.physical.rel.custom.wrappers.ParallelPostFit`.\n      Defaults to true for sklearn and single GPU cuML models and false otherwise.\n      Typically you set it to true for sklearn models if predicting on big data.\n    * ``wrap_fit``:\n      Boolean flag, whether to wrap the selected\n      model with a :class:`dask_sql.physical.rel.custom.wrappers.Incremental`.\n      Defaults to true for sklearn and single GPU cuML models and false otherwise.\n      Typically you set it to true for sklearn models if training on big data.\n    * ``fit_kwargs``:\n      keyword arguments sent to the call to ``fit()``.\n\nAll other arguments are passed to the constructor of the\nmodel class.\n\nExample:\n\n.. raw:: html\n\n    <div class=\"highlight-sql notranslate\"><div class=\"highlight\"><pre><span></span><span class=\"k\">CREATE MODEL</span> <span class=\"n\">my_model</span> <span class=\"k\">WITH</span> <span class=\"p\">(</span>\n        <span class=\"n\">model_class</span> <span class=\"o\">=</span> <span class=\"s1\">'XGBClassifier'</span><span class=\"p\">,</span>\n        <span class=\"n\">target_column</span> <span class=\"o\">=</span> <span class=\"s1\">'target'</span>\n    <span class=\"p\">)</span> <span class=\"k\">AS</span> <span class=\"p\">(</span>\n        <span class=\"k\">SELECT</span> <span class=\"n\">x</span><span class=\"p\">,</span> <span class=\"n\">y</span><span class=\"p\">,</span> <span class=\"n\">target</span>\n        <span class=\"k\">FROM</span> <span class=\"ss\">\"data\"</span>\n    <span class=\"p\">)</span>\n    </pre></div>\n    </div>\n\nThis SQL call is not a 1:1 replacement for a normal\npython training and can not fulfill all use-cases\nor requirements!\n\nIf you are dealing with large amounts of data,\nyou might run into problems while model training and/or\nprediction, 
depending on whether your model can cope with\ndask dataframes.\n\n    * If you are training on relatively small amounts\n      of data but predicting on large data samples,\n      you might want to set ``wrap_predict`` to True.\n      With this option, model inference will be\n      parallelized/distributed.\n    * If you are training on large amounts of data,\n      you can try setting ``wrap_fit`` to True. This will\n      do the same on the training step, but works only on\n      those models which have a ``partial_fit`` method.\n\n\n``DROP MODEL``\n--------------\n\nRemove the model with the given name from the registered models.\n\n\n``SELECT FROM PREDICT``\n-----------------------\n\nPredict the target using the given model and dataframe from the ``SELECT`` query.\nThe return value is the input dataframe with an additional column named\n\"target\", which contains the predicted values.\nThe model needs to be registered at the context before using it in this function,\neither by calling :func:`~dask_sql.Context.register_model` explicitly or by training\na model using the ``CREATE MODEL`` SQL statement above.\n\nA model can be anything which has a ``predict`` function.\nPlease note, however, that it will need to act on Dask dataframes. If you\nare using a model not optimized for this, you might run out of memory if\nyour data is larger than the RAM of a single machine.\nTo prevent this, have a look into the :class:`dask_sql.physical.rel.custom.wrappers.ParallelPostFit`\nmeta-estimator. If you are using a model trained with ``CREATE MODEL``\nand the ``wrap_predict`` flag set to true, this is done automatically.\n\nUsing this SQL statement is roughly equivalent to doing\n\n.. 
code-block:: python\n\n    df = context.sql(\"<select query>\")\n    model = ...  # get the registered model from the context\n\n    target = model.predict(df)\n    result = df.assign(target=target)\n\nThe select query is a normal ``SELECT`` query (following the same syntax as described in :ref:`select`)\nor even another call to ``PREDICT``.\n"
  },
  {
    "path": "docs/source/sql/select.rst",
    "content": ".. _select:\n\nData Retrieval\n==============\n\nQuery data from already created tables. The ``SELECT`` call follows mostly the standard SQL conventions,\nincluding all typical ingredients (such as ``WHERE``, ``GROUP BY``, ``ORDER BY`` etc.).\n\n.. raw:: html\n\n    <div class=\"highlight-sql notranslate\"><div class=\"highlight\"><pre><span></span><span class=\"k\">SELECT</span> <span class=\"p\">[</span> <span class=\"k\">ALL</span> <span class=\"o\">|</span> <span class=\"k\">DISTINCT</span> <span class=\"p\">]</span>\n        <span class=\"o\">*</span> <span class=\"o\">|</span> <span class=\"ss\">&lt;expression&gt;</span> <span class=\"p\">[</span> <span class=\"p\">[</span> <span class=\"k\">AS</span> <span class=\"p\">]</span> <span class=\"ss\">&lt;alias&gt;</span> <span class=\"p\">]</span> <span class=\"p\">[</span> <span class=\"p\">,</span> <span class=\"p\">...</span> <span class=\"p\">]</span>\n        <span class=\"p\">[</span> <span class=\"k\">FROM</span> <span class=\"ss\">&lt;from&gt;</span> <span class=\"p\">[ ,</span> <span class=\"p\">...</span> <span class=\"p\">]</span> <span class=\"p\">]</span>\n        <span class=\"p\">[</span> <span class=\"k\">WHERE</span> <span class=\"ss\">&lt;filter-condition&gt;</span> <span class=\"p\">]</span>\n        <span class=\"p\">[</span> <span class=\"k\">GROUP</span> <span class=\"k\">BY</span> <span class=\"ss\">&lt;group-by&gt;</span> <span class=\"p\">]</span>\n        <span class=\"p\">[</span> <span class=\"k\">HAVING</span> <span class=\"ss\">&lt;having-condition&gt;</span> <span class=\"p\">]</span>\n        <span class=\"p\">[</span> <span class=\"k\">UNION</span> <span class=\"p\">[</span> <span class=\"k\">ALL</span> <span class=\"o\">|</span> <span class=\"k\">DISTINCT</span> <span class=\"p\">]</span> <span class=\"ss\">&lt;select&gt;</span> <span class=\"p\">]</span>\n        <span class=\"p\">[</span> <span class=\"k\">ORDER</span> <span class=\"k\">BY</span> <span 
class=\"ss\">&lt;order-by&gt;</span> <span class=\"p\">[</span> <span class=\"k\">ASC</span> <span class=\"o\">|</span> <span class=\"k\">DESC</span> <span class=\"p\">]</span> <span class=\"p\">[</span> <span class=\"p\">,</span> <span class=\"p\">...</span> <span class=\"p\">]</span> <span class=\"p\">]</span>\n        <span class=\"p\">[</span> <span class=\"k\">LIMIT</span> <span class=\"ss\">&lt;end&gt;</span> <span class=\"p\">]</span>\n        <span class=\"p\">[</span> <span class=\"k\">OFFSET</span> <span class=\"ss\">&lt;start&gt;</span> <span class=\"p\">]</span>\n    </pre></div>\n    </div>\n\n.. note::\n\n    If you would like to help, please see `our issue tracker <https://github.com/dask-contrib/dask-sql/issues/43>`_.\n\nExample:\n\n.. code-block:: sql\n\n    SELECT\n        name, SUM(x) AS s\n    FROM\n        data\n    WHERE\n        y < 3 AND x > 0.5\n    GROUP BY\n        name\n    HAVING\n        SUM(x) < 5\n    UNION SELECT\n        'myself' AS name, 42 AS s\n    ORDER BY\n        s\n    LIMIT 100\n\nAll kinds of joins and (complex) subqueries are also possible:\n\n.. code-block:: sql\n\n    SELECT\n         lhs.name, lhs.id, lhs.x\n      FROM\n         data AS lhs\n      JOIN\n         (\n               SELECT\n                  name AS max_name,\n                  MAX(x) AS max_x\n               FROM timeseries\n               GROUP BY name\n         ) AS rhs\n      ON\n         lhs.name = rhs.max_name AND\n         lhs.x = rhs.max_x\n\nFor complex queries with many subqueries, it might be beneficial to use ``WITH``\nfor temporary table definitions:\n\n.. 
code-block:: sql\n\n    WITH tmp AS (\n        SELECT MAX(b) AS maxb FROM df GROUP BY a\n    )\n    SELECT\n        maxb\n    FROM tmp\n\nImplemented operations\n----------------------\n\nThe following list includes all operations understood and implemented in ``dask-sql``.\nScalar functions can be used to turn a column (or multiple) into a column of the same length (such as ``x + y`` or ``sin(x)``)\nwhereas aggregation functions can only be used in ``GROUP BY`` clauses, as they\nturn a column into a single value.\nFor more information on the semantics of the different functions, please have a look into the\n`Apache Calcite documentation <https://calcite.apache.org/docs/reference.html>`_.\n\nScalar Functions\n~~~~~~~~~~~~~~~~\n\nBinary Operations: ``AND``, ``OR``, ``>``, ``>=``, ``<``, ``<=``, ``=``, ``<>``, ``+``, ``-``, ``/``, ``*``\n\nUnary Math Operations: ``ABS``, ``ACOS``, ``ASIN``, ``ATAN``, ``ATAN2``, ``CBRT``, ``CEIL``, ``COS``, ``COT``, ``DEGREES``, ``EXP``, ``FLOOR``, ``LOG10``, ``LN``, ``POWER``, ``RADIANS``, ``ROUND``, ``SIGN``, ``SIN``, ``TAN``, ``TRUNCATE``\n\nString operations: ``LIKE``, ``SIMILAR TO``, ``||``, ``CHAR_LENGTH``, ``UPPER``, ``LOWER``, ``POSITION``, ``TRIM``, ``OVERLAY``, ``SUBSTRING``, ``INITCAP``\n\nDate operations: ``EXTRACT``, ``YEAR``, ``QUARTER``, ``MONTH``, ``WEEK``, ``DAYOFYEAR``, ``DAYOFMONTH``, ``DAYOFWEEK``, ``HOUR``, ``MINUTE``, ``SECOND``, ``LOCALTIME``, ``LOCALTIMESTAMP``, ``CURRENT_TIME``, ``CURRENT_DATE``, ``CURRENT_TIMESTAMP``\n\n.. note::\n\n    Due to a `bug/inconsistency <https://issues.apache.org/jira/browse/CALCITE-4313>`_ in Apache Calcite, both ``CURRENT_TIME`` and ``LOCALTIME`` return a time without timezone and therefore provide the same functionality.\n\nSpecial Operations: ``CASE``, ``NOT``, ``IS NULL``, ``IS NOT NULL``, ``IS TRUE``, ``IS NOT TRUE``, ``IS FALSE``, ``IS NOT FALSE``, ``IS UNKNOWN``, ``IS NOT UNKNOWN``, ``EXISTS``, ``RAND``, ``RAND_INTEGER``\n\nExample:\n\n.. 
code-block:: sql\n\n    SELECT\n        SIN(x)\n    FROM \"data\"\n    WHERE MONTH(t) = 4\n\n.. note::\n\n    It is also possible to implement custom functions. See :ref:`custom`.\n\nAggregations\n~~~~~~~~~~~~\n\n``ANY_VALUE``, ``AVG``, ``BIT_AND``, ``BIT_OR``, ``BIT_XOR``, ``COUNT``, ``EVERY``, ``MAX``, ``MIN``, ``SINGLE_VALUE``, ``STDDEV_POP``, ``STDDEV_SAMP``, ``SUM``, ``VAR_POP``, ``VAR_SAMP``, ``VARIANCE``\n\nExample:\n\n.. code-block:: sql\n\n    SELECT\n        SUM(x)\n    FROM \"data\"\n    GROUP BY y\n\nThe following statistical aggregation functions take two columns as input:\n\n``REGR_COUNT``, ``REGR_SXX``, ``REGR_SYY``, ``COVAR_POP``, ``COVAR_SAMP``\n\n.. code-block:: sql\n\n    SELECT\n        REGR_COUNT(y,x),\n        REGR_SXX(y,x),\n        COVAR_POP(y,x)\n    FROM \"data\"\n    GROUP BY z\n\n\n.. note::\n\n    It is also possible to implement custom aggregations. See :ref:`custom`.\n\nWindowing/Over\n~~~~~~~~~~~~~~\n\n``ROW_NUMBER``, ``SUM``, ``AVG``, ``COUNT``, ``MAX``, ``MIN``, ``SINGLE_VALUE``, ``FIRST_VALUE``, ``LAST_VALUE``\n\nExample:\n\n.. code-block:: sql\n\n    SELECT\n        y,\n        SUM(x) OVER (PARTITION BY z ORDER BY a NULLS FIRST)\n    FROM \"data\"\n\n.. note::\n\n    Again, it is also possible to implement custom windowing functions.\n\nTable Functions\n~~~~~~~~~~~~~~~\n\n``TABLESAMPLE SYSTEM`` and ``TABLESAMPLE BERNOULLI``:\n\nExample:\n\n.. code-block:: sql\n\n    SELECT * FROM \"data\" TABLESAMPLE BERNOULLI (20) REPEATABLE (42)\n\n``TABLESAMPLE`` allows you to draw random samples from the given table and should be the preferred way\nto select samples. ``BERNOULLI`` will select a row in the original table with a probability\ngiven by the number in the brackets (in percent). 
The optional flag ``REPEATABLE`` defines\nthe random seed to use.\n``SYSTEM`` is similar, but acts on partitions (so blocks of data) and is therefore much less\naccurate; it should only be used on really large data samples where ``BERNOULLI`` is not\nfast enough (which is very unlikely).\n"
  },
  {
    "path": "docs/source/sql.rst",
    "content": ".. _sql:\n\nSQL Syntax\n==========\n\n``dask-sql`` understands SQL in (mostly) presto SQL syntax.\nSo far, not every valid SQL operator and keyword is already\nimplemented in ``dask-sql``, but a large fraction of it.\nHave a look into our `issue tracker <https://github.com/dask-contrib/dask-sql/issues>`_\nto find out what is still missing.\n\n``dask-sql`` understands queries for data retrieval (``SELECT``), queries on metadata information (``SHOW`` and ``DESCRIBE``), queries for table creation (``CREATE TABLE``) and machine learning (``CREATE MODEL`` and ``PREDICT``).\nIn the following, general information for these queries are given - the sub-pages give details on each of the implemented keywords or operations.\nThe information on these pages apply to all ways SQL queries can be handed over to ``dask-sql``: via Python (:ref:`api`), the SQL server (:ref:`server`) or the command line (:ref:`cmd`).\n\nGeneral\n-------\n\nData in ``dask-sql`` is - similar to most SQL systems - grouped in named tables, which consist of columns (with names and data types) and rows.\nThe tables are again grouped into schemas.\nFor simplicity, there only exists a single schema, named \"schema\".\n\nFor many queries, it is necessary to refer to a schema, table, or column.\nIdentifiers can be specified with double quotes or without quotes (if there is no ambiguity with SQL keywords).\nCasing will be kept (with or without quotes).\n\n.. code-block:: sql\n\n    SELECT\n        \"date\", \"name\"\n    FROM\n        \"df\"\n\n``\"date\"`` definitely needs quotation marks (as ``DATE`` is also an SQL keyword), but ``name`` and ``df`` can also be specified without quotation marks.\n\nTo prevent ambiguities, the full table identifier can be used:\n\n.. code-block:: sql\n\n    SELECT\n        \"df\".\"date\"\n    FROM\n        \"schema\".\"df\"\n\nIn many cases however, the bare name is enough:\n\n.. 
code-block:: sql\n\n    SELECT\n        \"date\"\n    FROM\n        \"df\"\n\nString literals get single quotes:\n\n.. code-block:: sql\n\n    SELECT 'string literal'\n\n.. note::\n    ``dask-sql`` can only understand a single SQL query per call to ``Context.sql``.\n    Therefore, there should also be no semicolons after the query.\n\n\nSome SQL statements, like ``CREATE MODEL WITH`` and ``CREATE TABLE WITH``, expect a list of key-value arguments,\nwhich resembles (not accidentally) a Python dictionary. They are in the form\n\n.. code-block:: none\n\n    (\n        key = value\n        [ , ... ]\n    )\n\nwith an arbitrary number of key-value pairs and are always enclosed in parentheses. The keys are (similar to Python's ``dict`` constructor) unquoted.\nA value can be any valid SQL literal (e.g. ``3``, ``4.2``, ``'string'``), a key-value parameter list itself\nor a list (``ARRAY``) or a set (``MULTISET``) (or another way of writing dictionaries with ``MAP``).\n\nThis means that the following is a valid key-value parameter list:\n\n.. code-block:: sql\n\n    (\n        first_argument = 3,\n        second_argument = MULTISET [ 1, 1, 2, 3 ],\n        third_argument = (\n            sub_argument_1 = ARRAY [ 1, 2, 3 ],\n            sub_argument_2 = 'a string'\n        )\n    )\n\nPlease note that, in contrast to Python, no comma is allowed after the last argument.\n\nQuery Types and Reference\n-------------------------\n\n.. 
toctree::\n   :maxdepth: 1\n\n   sql/select.rst\n   sql/creation.rst\n   sql/ml.rst\n   sql/describe.rst\n\n\nImplemented Types\n-----------------\n\n``dask-sql`` needs to map between SQL and ``dask`` (Python) types.\nFor this, it uses the following mapping:\n\n+-----------------------+----------------+\n| From Python Type      | To SQL Type    |\n+=======================+================+\n| ``np.bool_``          |  ``BOOLEAN``   |\n+-----------------------+----------------+\n| ``np.datetime64``     |  ``TIMESTAMP`` |\n+-----------------------+----------------+\n| ``np.float32``        |  ``FLOAT``     |\n+-----------------------+----------------+\n| ``np.float64``        |  ``DOUBLE``    |\n+-----------------------+----------------+\n| ``np.int16``          |  ``SMALLINT``  |\n+-----------------------+----------------+\n| ``np.int32``          |  ``INTEGER``   |\n+-----------------------+----------------+\n| ``np.int64``          |  ``BIGINT``    |\n+-----------------------+----------------+\n| ``np.int8``           |  ``TINYINT``   |\n+-----------------------+----------------+\n| ``np.object_``        |  ``VARCHAR``   |\n+-----------------------+----------------+\n| ``np.uint16``         |  ``SMALLINT``  |\n+-----------------------+----------------+\n| ``np.uint32``         |  ``INTEGER``   |\n+-----------------------+----------------+\n| ``np.uint64``         |  ``BIGINT``    |\n+-----------------------+----------------+\n| ``np.uint8``          |  ``TINYINT``   |\n+-----------------------+----------------+\n| ``pd.BooleanDtype``   |  ``BOOLEAN``   |\n+-----------------------+----------------+\n| ``pd.Int16Dtype``     |  ``SMALLINT``  |\n+-----------------------+----------------+\n| ``pd.Int32Dtype``     |  ``INTEGER``   |\n+-----------------------+----------------+\n| ``pd.Int64Dtype``     |  ``BIGINT``    |\n+-----------------------+----------------+\n| ``pd.Int8Dtype``      |  ``TINYINT``   |\n+-----------------------+----------------+\n| ``pd.StringDtype``  
  |  ``VARCHAR``   |\n+-----------------------+----------------+\n| ``pd.UInt16Dtype``    |  ``SMALLINT``  |\n+-----------------------+----------------+\n| ``pd.UInt32Dtype``    |  ``INTEGER``   |\n+-----------------------+----------------+\n| ``pd.UInt64Dtype``    |  ``BIGINT``    |\n+-----------------------+----------------+\n| ``pd.UInt8Dtype``     |  ``TINYINT``   |\n+-----------------------+----------------+\n\n+-------------------+-----------------------------+\n| From SQL Type     | To Python Type              |\n+===================+=============================+\n| ``BIGINT``        |    ``pd.Int64Dtype``        |\n+-------------------+-----------------------------+\n| ``BOOLEAN``       |    ``pd.BooleanDtype``      |\n+-------------------+-----------------------------+\n| ``CHAR(*)``       |    ``pd.StringDtype``       |\n+-------------------+-----------------------------+\n| ``DATE``          |    ``np.dtype(\"<M8[ns]\")``  |\n+-------------------+-----------------------------+\n| ``DECIMAL(*)``    |    ``np.float64``           |\n+-------------------+-----------------------------+\n| ``DOUBLE``        |    ``np.float64``           |\n+-------------------+-----------------------------+\n| ``FLOAT``         |    ``np.float32``           |\n+-------------------+-----------------------------+\n| ``INTEGER``       |    ``pd.Int32Dtype()``      |\n+-------------------+-----------------------------+\n| ``INTERVAL``      |    ``np.dtype(\"<m8[ns]\")``  |\n+-------------------+-----------------------------+\n| ``SMALLINT``      |    ``pd.Int16Dtype()``      |\n+-------------------+-----------------------------+\n| ``TIME(*)``       |    ``np.dtype(\"<M8[ns]\")``  |\n+-------------------+-----------------------------+\n| ``TIMESTAMP(*)``  |    ``np.dtype(\"<M8[ns]\")``  |\n+-------------------+-----------------------------+\n| ``TINYINT``       |    ``pd.Int8Dtype``         |\n+-------------------+-----------------------------+\n| ``VARCHAR``       |    
``pd.StringDtype``       |\n+-------------------+-----------------------------+\n| ``VARCHAR(*)``    |    ``pd.StringDtype``       |\n+-------------------+-----------------------------+\n\nLimitations\n-----------\n\n``dask-sql`` is still in early development, so some limitations exist:\n\nNot all operations and aggregations are implemented yet.\n\n.. note::\n    Whenever you find an operation, keyword or functionality that is not yet implemented,\n    please raise an issue at our `issue tracker <https://github.com/dask-contrib/dask-sql/issues>`_ with your use-case.\n\nDask/pandas and SQL treat null-values (or nan) differently on sorting, grouping and joining.\n``dask-sql`` tries to follow the SQL standard as much as possible, so results might be different to what you expect from Dask/Pandas.\n\nApart from those functional limitations, there is one operation which needs special care: ``ORDER BY``.\nNormally, ``dask-sql`` calls create a ``dask`` data frame, which only gets computed when you call the ``.compute()`` member.\nDue to internal constraints, this is currently not the case for ``ORDER BY``.\nIncluding this operation will trigger a calculation of the full data frame already when calling ``Context.sql()``.\n\n.. warning::\n    There is a subtle but important difference between adding ``LIMIT 10`` to your SQL query and calling ``sql(...).head(10)``.\n    The data inside ``dask`` is partitioned, to distribute it over the cluster.\n    ``head`` will only return the first N elements from the first partition - even if N is larger than the partition size.\n    As a benefit, calling ``.head(N)`` is typically faster than calculating the full data sample with ``.compute()``.\n    ``LIMIT`` on the other hand will always return the first N elements - no matter how many partitions they are scattered over -\n    but will also need to precalculate the first partition to find out whether it needs to look into all data or not.\n"
  },
  {
    "path": "notebooks/Custom Functions.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Custom Functions\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Apart from the SQL functions that are already implemented in `dask-sql`, it is possible to add custom functions and aggregations.\\n\",\n    \"Have a look into [the documentation](https://dask-sql.readthedocs.io/en/latest/pages/custom.html) for more information.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import numpy as np\\n\",\n    \"import dask.dataframe as dd\\n\",\n    \"import dask.datasets\\n\",\n    \"from dask_sql.context import Context\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"We use some generated test data for the notebook:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"c = Context()\\n\",\n    \"# Allows us to use the %%sql magic function\\n\",\n    \"c.ipython_magic()\\n\",\n    \"\\n\",\n    \"df = dask.datasets.timeseries().reset_index().persist()\\n\",\n    \"c.create_table(\\\"timeseries\\\", df)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"As a first step, we will create a scalar function to calculate the absolute value of a column.\\n\",\n    \"(Please note that this can also be done via the `ABS` function in SQL):\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# The input to the function will be a dask series\\n\",\n    \"def my_abs(x):\\n\",\n    \"    return x.abs()\\n\",\n    \"\\n\",\n    \"# As SQL is a typed language, we need to specify all types \\n\",\n    \"c.register_function(my_abs, 
\\\"MY_ABS\\\", parameters=[(\\\"x\\\", np.float64)], return_type=np.float64)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"We are now able to use our new function in all queries\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"    SELECT\\n\",\n    \"        x, y, MY_ABS(x) AS \\\"abs_x\\\", MY_ABS(y) AS \\\"abs_y\\\"\\n\",\n    \"    FROM\\n\",\n    \"        \\\"timeseries\\\"\\n\",\n    \"    WHERE\\n\",\n    \"        MY_ABS(x * y) > 0.5\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Next, we will register an aggregation, which gets a column as input and returns a single value.\\n\",\n    \"An aggregation needs to be an instance of `dask.Aggregation` (see the [dask docu](https://docs.dask.org/en/latest/dataframe-groupby.html#aggregate)).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"my_sum = dd.Aggregation(\\\"MY_SUM\\\", lambda x: x.sum(), lambda x: x.sum())\\n\",\n    \"\\n\",\n    \"c.register_aggregation(my_sum, \\\"MY_SUM\\\", [(\\\"x\\\", np.float64)], np.float64)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"    SELECT\\n\",\n    \"        name, MY_SUM(x) AS \\\"my_sum\\\"\\n\",\n    \"    FROM\\n\",\n    \"        \\\"timeseries\\\"\\n\",\n    \"    GROUP BY\\n\",\n    \"        name\\n\",\n    \"    LIMIT 10\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n  
 \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.6\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "notebooks/Feature Overview.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Dask-SQL\\n\",\n    \"### A SQL Query Layer for Dask\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Introduction\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"`dask-sql` adds a SQL query layer on top of the Dask distributed Python library, which allows you to query your big and small data with SQL and still use the great power of the Dask ecosystem.\\n\",\n    \"It helps you combine the best of both worlds.\\n\",\n    \"See the [documentation](https://dask-sql.readthedocs.io/) for more information.\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Starting Dask-SQL\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"There are two possibilities how you can send your SQL queries to `dask-sql`:\\n\",\n    \"* you use a Python notebook/script, such as the one you have currently opened\\n\",\n    \"* you run the [dask-sql Server](https://dask-sql.readthedocs.io/en/latest/pages/server.html) as a standalone application and connect to it via e.g. 
your BI tool\\n\",\n    \"\\n\",\n    \"We will stick with the first possibility in this notebook, but all SQL commands shown here can also be run via the SQL server.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Before we start, we need to import `dask-sql` and create a `Context`, which collects all the information on the currently registered data tables.\\n\",\n    \"We will also create a small local Dask cluster (this step is not needed, but gives us a few more debugging options).\\n\",\n    \"If you have a large computation cluster, you can connect to it in this step (have a look [here](https://docs.dask.org/en/latest/setup.html)).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from dask_sql import Context\\n\",\n    \"from dask.distributed import Client\\n\",\n    \"\\n\",\n    \"client = Client()\\n\",\n    \"c = Context()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"client\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"You are now ready to query with SQL!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"c.sql(\\\"\\\"\\\"\\n\",\n    \"    SELECT 42 AS \\\"the answer\\\"\\n\",\n    \"\\\"\\\"\\\", return_futures=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"A shortcut for the queries that follow:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"c.ipython_magic(auto_include=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"This line allows us to write (instead 
of the line above)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"SELECT 42 AS \\\"the answer\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Data Input\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 1. From a Dask Dataframe via Python\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import dask.dataframe as dd\\n\",\n    \"\\n\",\n    \"df = dd.read_csv(\\\"./iris.csv\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"df.head(10)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"c.create_table(\\\"iris\\\", df)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 2. 
From an external data source via SQL\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"CREATE OR REPLACE TABLE iris\\n\",\n    \"WITH (\\n\",\n    \"    location = 'file://./iris.csv',\\n\",\n    \"    format = 'csv',\\n\",\n    \"    persist = True\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"* S3, Azure, DBFS, GS, hdfs, ...\\n\",\n    \"* Hive (experimental), Databricks (experimental), Intake\\n\",\n    \"* already loaded data persisted in your Dask cluster\\n\",\n    \"\\n\",\n    \"More [information](https://dask-sql.readthedocs.io/en/latest/pages/data_input.html)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 3. As materialized Queries\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"CREATE OR REPLACE TABLE second_iris\\n\",\n    \"AS SELECT * FROM iris\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### 4. 
From the notebook\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"As we have created an ipython magic with `c.ipython_magic(auto_include=True)` we can even just reference any dataframe created in the notebook in our queries.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"my_data_frame = dd.read_csv(\\\"./iris.csv\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"SELECT * FROM my_data_frame LIMIT 10\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Please note that using this setting will automatically override any predefined tables with the same name.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Metadata Information\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"SHOW TABLES FROM \\\"schema\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"SHOW COLUMNS FROM \\\"iris\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"DESCRIBE iris\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"DESCRIBE TABLE iris\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Data Query\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"You can call 
\\\"normal\\\" SQL `SELECT` statements in `dask-sql`, with all typical components from the standard SQL language.\\n\",\n    \"More information in the [SQL reference](https://dask-sql.readthedocs.io/en/latest/pages/sql.html).\\n\",\n    \"`dask-sql` roughly follows the prestoSQL conventions (e.g. quoting).\\n\",\n    \"\\n\",\n    \"<div class=\\\"alert alert-info\\\">\\n\",\n    \"    \\n\",\n    \"#### Note\\n\",\n    \"    \\n\",\n    \"Not all SQL operators are implemented in `dask-sql` already.\\n\",\n    \"    \\n\",\n    \"</div>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"SELECT * \\n\",\n    \"FROM iris\\n\",\n    \"LIMIT 10\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"SELECT \\n\",\n    \"    sepal_length + sepal_width AS \\\"sum\\\", \\n\",\n    \"    SIN(petal_length) AS \\\"sin\\\"\\n\",\n    \"FROM iris\\n\",\n    \"LIMIT 10\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"SELECT \\n\",\n    \"    species,\\n\",\n    \"    AVG(sepal_length) AS sepal_length, \\n\",\n    \"    AVG(sepal_width) AS sepal_width\\n\",\n    \"FROM iris\\n\",\n    \"GROUP BY species\\n\",\n    \"LIMIT 10\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"WITH maximal_values AS (\\n\",\n    \"    SELECT \\n\",\n    \"        species, \\n\",\n    \"        MAX(sepal_length) AS sepal_length\\n\",\n    \"    FROM iris\\n\",\n    \"    GROUP BY species\\n\",\n    \")\\n\",\n    \"SELECT lhs.*\\n\",\n    \"FROM iris AS lhs \\n\",\n    \"JOIN maximal_values AS rhs ON lhs.species = rhs.species AND lhs.sepal_length = 
rhs.sepal_length\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(c.explain(\\\"\\\"\\\"\\n\",\n    \"    WITH maximal_values AS (\\n\",\n    \"        SELECT \\n\",\n    \"            species, \\n\",\n    \"            MAX(sepal_length) AS sepal_length\\n\",\n    \"        FROM iris\\n\",\n    \"        GROUP BY species\\n\",\n    \"    )\\n\",\n    \"    SELECT \\n\",\n    \"        lhs.*\\n\",\n    \"    FROM iris AS lhs \\n\",\n    \"    JOIN maximal_values AS rhs\\n\",\n    \"    ON lhs.species = rhs.species \\n\",\n    \"        AND lhs.sepal_length = rhs.sepal_length\\n\",\n    \"\\\"\\\"\\\"))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Custom Functions\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"def volume(length, width):\\n\",\n    \"    return (width / 2) ** 2 * np.pi * length\\n\",\n    \"\\n\",\n    \"# As SQL is a typed language, we need to specify all types \\n\",\n    \"c.register_function(volume, \\\"IRIS_VOLUME\\\", \\n\",\n    \"                    parameters=[(\\\"length\\\", np.float64), (\\\"width\\\", np.float64)], \\n\",\n    \"                    return_type=np.float64)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"SELECT \\n\",\n    \"    sepal_length, sepal_width, IRIS_VOLUME(sepal_length, sepal_width) AS volume\\n\",\n    \"FROM iris\\n\",\n    \"LIMIT 10\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Machine Learning\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"df.species.head(100)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"CREATE OR REPLACE TABLE enriched_iris AS (\\n\",\n    \"    SELECT \\n\",\n    \"        sepal_length, sepal_width, petal_length, petal_width,\\n\",\n    \"        CASE \\n\",\n    \"            WHEN species = 'setosa' THEN 0 ELSE CASE \\n\",\n    \"            WHEN species = 'versicolor' THEN 1\\n\",\n    \"            ELSE 2 \\n\",\n    \"        END END AS \\\"species\\\", \\n\",\n    \"        IRIS_VOLUME(sepal_length, sepal_width) AS volume\\n\",\n    \"    FROM iris \\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"CREATE OR REPLACE TABLE training_data AS (\\n\",\n    \"    SELECT \\n\",\n    \"        *\\n\",\n    \"    FROM enriched_iris\\n\",\n    \"    TABLESAMPLE BERNOULLI (50)\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"SELECT * FROM training_data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"CREATE OR REPLACE MODEL my_model WITH (\\n\",\n    \"    model_class = 'DaskXGBClassifier',\\n\",\n    \"    target_column = 'species',\\n\",\n    \"    num_class = 3\\n\",\n    \") AS (\\n\",\n    \"    SELECT * FROM training_data\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"SHOW MODELS\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"DESCRIBE MODEL 
my_model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"SELECT\\n\",\n    \"    *\\n\",\n    \"FROM PREDICT(\\n\",\n    \"    MODEL my_model,\\n\",\n    \"    SELECT * FROM enriched_iris\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"CREATE OR REPLACE TABLE results AS\\n\",\n    \"SELECT\\n\",\n    \"    *\\n\",\n    \"FROM PREDICT(\\n\",\n    \"    MODEL my_model,\\n\",\n    \"    TABLE enriched_iris\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Create Experiment \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"- Tune a single model with different hyperparameters\\n\",\n    \"  - install **sklearn** for tuning\\n\",\n    \"- Tune multiple models with different hyperparameters\\n\",\n    \"  - install **tpot** for AutoML\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"CREATE EXPERIMENT my_exp WITH (\\n\",\n    \"        model_class = 'GradientBoostingClassifier',\\n\",\n    \"        experiment_class = 'GridSearchCV',\\n\",\n    \"        tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001],\\n\",\n    \"                           max_depth = ARRAY [3,4,5,10]),\\n\",\n    \"        target_column = 'species'\\n\",\n    \"    ) AS (\\n\",\n    \"            SELECT * FROM training_data\\n\",\n    \"        )\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## AutoML in SQL\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"CREATE EXPERIMENT my_automl_exp WITH (\\n\",\n    \"            automl_class = 'tpot.TPOTClassifier',\\n\",\n    \"            automl_kwargs = (population_size = 2 ,generations=5,cv=2,n_jobs=-1,use_dask=True),\\n\",\n    \"            target_column = 'species'\\n\",\n    \"        ) AS (\\n\",\n    \"            SELECT * FROM training_data \\n\",\n    \"        )\\n\",\n    \"-- while model was training, checkout the visualization in dask-dashboard for each generation of tasks \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql \\n\",\n    \"show models\\n\",\n    \"\\n\",\n    \"-- once the experiment was completed, Best model was\\n\",\n    \"-- stored in context which can be used for prediction\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Export ML models\\n\",\n    \"\\n\",\n    \"- export trained models and serve the model as microservice\\n\",\n    \"- supports Pickle, Joblib, MLflow formats\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"-- pickle  export\\n\",\n    \"EXPORT MODEL my_model with (\\n\",\n    \"            format ='pickle',\\n\",\n    \"            location = 'my_model.pkl'\\n\",\n    \"    )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"-- joblib export\\n\",\n    \"EXPORT MODEL my_exp_GradientBoostingClassifier_best_model\\n\",\n    \"with (\\n\",\n    \"            format ='joblib',\\n\",\n    \"            location = 'best_mode.joblib'\\n\",\n    \"    )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   
\"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"-- mlflow export\\n\",\n    \"EXPORT MODEL automl_TPOTClassifier\\n\",\n    \"with (\\n\",\n    \"            format ='mlflow',\\n\",\n    \"            location = 'model_dir'\\n\",\n    \"    )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%%sql\\n\",\n    \"SELECT\\n\",\n    \"    target,\\n\",\n    \"    species,\\n\",\n    \"    COUNT(*)\\n\",\n    \"FROM\\n\",\n    \"    results\\n\",\n    \"GROUP BY target, species\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"t = c.sql(\\\"\\\"\\\"\\n\",\n    \"    SELECT\\n\",\n    \"        target,\\n\",\n    \"        species,\\n\",\n    \"        COUNT(*) AS \\\"number\\\"\\n\",\n    \"    FROM\\n\",\n    \"        results\\n\",\n    \"    GROUP BY target, species\\n\",\n    \"\\\"\\\"\\\").compute() \\n\",\n    \"t.set_index([\\\"target\\\", \\\"species\\\"]).unstack(\\\"species\\\").number.plot.bar()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"Python 3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.6\"\n  },\n  \"toc-autonumbering\": false,\n  \"toc-showcode\": false,\n  \"toc-showmarkdowntxt\": false\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "notebooks/FugueSQL.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f39e2dbc-21a1-4d9a-bed7-e2bf2bd25bb8\",\n   \"metadata\": {},\n   \"source\": [\n    \"# FugueSQL Integrations\\n\",\n    \"\\n\",\n    \"[FugueSQL](https://fugue-tutorials.readthedocs.io/tutorials/fugue_sql/index.html) is a related project that aims to provide a unified SQL interface for a variety of different computing frameworks, including Dask.\\n\",\n    \"While it offers a SQL engine with a larger set of supported commands, this comes at the cost of slower performance when using Dask in comparison to dask-sql.\\n\",\n    \"In order to offer a \\\"best of both worlds\\\" solution, dask-sql can easily be integrated with FugueSQL, using its faster implementation of SQL commands when possible and falling back on FugueSQL's implementation when necessary.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"90e31400\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Setup\\n\",\n    \"\\n\",\n    \"FugueSQL offers the cell magic `%%fsql`, which can be used to define and execute queries entirely in SQL, with no need for external Python code!\\n\",\n    \"\\n\",\n    \"To use this cell magic, users must install [fugue-jupyter](https://pypi.org/project/fugue-jupyter/), which will additionally provide SQL syntax highlighting (note that the kernel must be restart after installing):\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"96c3ad1a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"!pip install fugue-jupyter\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ae79361a\",\n   \"metadata\": {},\n   \"source\": [\n    \"And run `fugue_jupyter.setup()` to register the magic:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"2df05f5b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fugue_jupyter import setup\\n\",\n    
\"\\n\",\n    \"setup()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d3b8bfe5\",\n   \"metadata\": {},\n   \"source\": [\n    \"We will also start up a Dask client, which can be specified as an execution engine for FugueSQL queries:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"a35d98e6-f24e-46c4-a4e6-b64d649d8ba7\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from dask.distributed import Client\\n\",\n    \"\\n\",\n    \"client = Client()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"bcb96523\",\n   \"metadata\": {},\n   \"source\": [\n    \"## dask-sql as a FugueSQL execution engine\\n\",\n    \"\\n\",\n    \"When dask-sql is installed, its `DaskSQLExecutionEngine` is automatically registered as the default engine for FugueSQL queries ran on Dask.\\n\",\n    \"We can then use it to run queries with the `%%fsql` cell magic, specifying `dask` as the execution engine to run the query on:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"ff633572-ad08-4de1-8678-a8fbd09effd1\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th {\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>a</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    
<tr>\\n\",\n       \"      <th>0</th>\\n\",\n       \"      <td>xyz</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"</div>\"\n      ],\n      \"text/plain\": [\n       \"     a\\n\",\n       \"0  xyz\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<small>schema: a:str</small>\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"%%fsql dask\\n\",\n    \"\\n\",\n    \"CREATE [[\\\"xyz\\\"], [\\\"xxx\\\"]] SCHEMA a:str\\n\",\n    \"SELECT * WHERE a LIKE '%y%'\\n\",\n    \"PRINT\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7f16b7d9-6b45-4caf-bbcb-63cc5d858556\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can also use the `YIELD` keyword to register the results of our queries into Python objects:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"521965bc-1a4c-49ab-b48f-789351cb24d4\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th {\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>b</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  
<tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>0</th>\\n\",\n       \"      <td>xyz</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>1</th>\\n\",\n       \"      <td>xxx-</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"</div>\"\n      ],\n      \"text/plain\": [\n       \"      b\\n\",\n       \"0   xyz\\n\",\n       \"1  xxx-\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<small>schema: b:str</small>\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"%%fsql dask\\n\",\n    \"src = CREATE [[\\\"xyz\\\"], [\\\"xxx\\\"]] SCHEMA a:str\\n\",\n    \"\\n\",\n    \"a = SELECT a AS b WHERE a LIKE '%y%'\\n\",\n    \"    YIELD DATAFRAME AS test\\n\",\n    \"\\n\",\n    \"b = SELECT CONCAT(a, '-') AS b FROM src WHERE a LIKE '%xx%'\\n\",\n    \"    YIELD DATAFRAME AS test1\\n\",\n    \"\\n\",\n    \"SELECT * FROM a UNION SELECT * FROM b\\n\",\n    \"PRINT\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"dfbb0a9a\",\n   \"metadata\": {},\n   \"source\": [\n    \"Which can then be interacted with outside of SQL:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"79a3e87a-2764-410c-b257-c710c4a6c6d4\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div><strong>Dask DataFrame Structure:</strong></div>\\n\",\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th {\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    }\\n\",\n       
\"\\n\",\n       \"    .dataframe thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>b</th>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>npartitions=2</th>\\n\",\n       \"      <th></th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <th></th>\\n\",\n       \"      <td>object</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th></th>\\n\",\n       \"      <td>...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th></th>\\n\",\n       \"      <td>...</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"</div>\\n\",\n       \"<div>Dask Name: rename, 16 tasks</div>\"\n      ],\n      \"text/plain\": [\n       \"Dask DataFrame Structure:\\n\",\n       \"                    b\\n\",\n       \"npartitions=2        \\n\",\n       \"               object\\n\",\n       \"                  ...\\n\",\n       \"                  ...\\n\",\n       \"Dask Name: rename, 16 tasks\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test.native  # a Dask DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"c98cb652-06e2-444a-b70a-fdd3de9ecd15\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th 
{\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>b</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>1</th>\\n\",\n       \"      <td>xxx-</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"</div>\"\n      ],\n      \"text/plain\": [\n       \"      b\\n\",\n       \"1  xxx-\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test1.native.compute()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"932ede31-90b2-49e5-9f4d-7cf1b8d919d2\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can also run the equivalent of these queries in python code using `fugue_sql.fsql`, passing the Dask client into its `run` method to specify Dask as an execution engine:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"c265b170-de4d-4fab-aeae-9f94031e960d\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th {\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table 
border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>a</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>0</th>\\n\",\n       \"      <td>xyz</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"</div>\"\n      ],\n      \"text/plain\": [\n       \"     a\\n\",\n       \"0  xyz\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<small>schema: a:str</small>\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"DataFrames()\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"from fugue_sql import fsql\\n\",\n    \"\\n\",\n    \"fsql(\\\"\\\"\\\"\\n\",\n    \"CREATE [[\\\"xyz\\\"], [\\\"xxx\\\"]] SCHEMA a:str\\n\",\n    \"SELECT * WHERE a LIKE '%y%'\\n\",\n    \"PRINT\\n\",\n    \"\\\"\\\"\\\").run(client)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"77e3bf50-8c8b-4e2f-a5e7-28b1d86499d7\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"<div><strong>Dask DataFrame Structure:</strong></div>\\n\",\n       \"<div>\\n\",\n       \"<style scoped>\\n\",\n       \"    .dataframe tbody tr th:only-of-type {\\n\",\n       \"        vertical-align: middle;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe tbody tr th {\\n\",\n       \"        vertical-align: top;\\n\",\n       \"    }\\n\",\n       \"\\n\",\n       \"    .dataframe 
thead th {\\n\",\n       \"        text-align: right;\\n\",\n       \"    }\\n\",\n       \"</style>\\n\",\n       \"<table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \"    <tr style=\\\"text-align: right;\\\">\\n\",\n       \"      <th></th>\\n\",\n       \"      <th>a</th>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th>npartitions=2</th>\\n\",\n       \"      <th></th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"    <tr>\\n\",\n       \"      <th></th>\\n\",\n       \"      <td>object</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th></th>\\n\",\n       \"      <td>...</td>\\n\",\n       \"    </tr>\\n\",\n       \"    <tr>\\n\",\n       \"      <th></th>\\n\",\n       \"      <td>...</td>\\n\",\n       \"    </tr>\\n\",\n       \"  </tbody>\\n\",\n       \"</table>\\n\",\n       \"</div>\\n\",\n       \"<div>Dask Name: rename, 16 tasks</div>\"\n      ],\n      \"text/plain\": [\n       \"Dask DataFrame Structure:\\n\",\n       \"                    a\\n\",\n       \"npartitions=2        \\n\",\n       \"               object\\n\",\n       \"                  ...\\n\",\n       \"                  ...\\n\",\n       \"Dask Name: rename, 16 tasks\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"result = fsql(\\\"\\\"\\\"\\n\",\n    \"CREATE [[\\\"xyz\\\"], [\\\"xxx\\\"]] SCHEMA a:str\\n\",\n    \"SELECT * WHERE a LIKE '%y%'\\n\",\n    \"YIELD DATAFRAME AS test2\\n\",\n    \"\\\"\\\"\\\").run(client)\\n\",\n    \"\\n\",\n    \"result[\\\"test2\\\"].native  # a Dask DataFrame\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"7d4c71d4-238f-4c72-8609-dbbe0782aea9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": 
{\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.13\"\n  },\n  \"vscode\": {\n   \"interpreter\": {\n    \"hash\": \"656801d214ad98d4b301386b078628ce3ae2dbd81a59ed4deed7a5b13edfab09\"\n   }\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "notebooks/iris.csv",
    "content": "sepal_length,sepal_width,petal_length,petal_width,species\n5.1,3.5,1.4,0.2,setosa\n4.9,3.0,1.4,0.2,setosa\n4.7,3.2,1.3,0.2,setosa\n4.6,3.1,1.5,0.2,setosa\n5.0,3.6,1.4,0.2,setosa\n5.4,3.9,1.7,0.4,setosa\n4.6,3.4,1.4,0.3,setosa\n5.0,3.4,1.5,0.2,setosa\n4.4,2.9,1.4,0.2,setosa\n4.9,3.1,1.5,0.1,setosa\n5.4,3.7,1.5,0.2,setosa\n4.8,3.4,1.6,0.2,setosa\n4.8,3.0,1.4,0.1,setosa\n4.3,3.0,1.1,0.1,setosa\n5.8,4.0,1.2,0.2,setosa\n5.7,4.4,1.5,0.4,setosa\n5.4,3.9,1.3,0.4,setosa\n5.1,3.5,1.4,0.3,setosa\n5.7,3.8,1.7,0.3,setosa\n5.1,3.8,1.5,0.3,setosa\n5.4,3.4,1.7,0.2,setosa\n5.1,3.7,1.5,0.4,setosa\n4.6,3.6,1.0,0.2,setosa\n5.1,3.3,1.7,0.5,setosa\n4.8,3.4,1.9,0.2,setosa\n5.0,3.0,1.6,0.2,setosa\n5.0,3.4,1.6,0.4,setosa\n5.2,3.5,1.5,0.2,setosa\n5.2,3.4,1.4,0.2,setosa\n4.7,3.2,1.6,0.2,setosa\n4.8,3.1,1.6,0.2,setosa\n5.4,3.4,1.5,0.4,setosa\n5.2,4.1,1.5,0.1,setosa\n5.5,4.2,1.4,0.2,setosa\n4.9,3.1,1.5,0.1,setosa\n5.0,3.2,1.2,0.2,setosa\n5.5,3.5,1.3,0.2,setosa\n4.9,3.1,1.5,0.1,setosa\n4.4,3.0,1.3,0.2,setosa\n5.1,3.4,1.5,0.2,setosa\n5.0,3.5,1.3,0.3,setosa\n4.5,2.3,1.3,0.3,setosa\n4.4,3.2,1.3,0.2,setosa\n5.0,3.5,1.6,0.6,setosa\n5.1,3.8,1.9,0.4,setosa\n4.8,3.0,1.4,0.3,setosa\n5.1,3.8,1.6,0.2,setosa\n4.6,3.2,1.4,0.2,setosa\n5.3,3.7,1.5,0.2,setosa\n5.0,3.3,1.4,0.2,setosa\n7.0,3.2,4.7,1.4,versicolor\n6.4,3.2,4.5,1.5,versicolor\n6.9,3.1,4.9,1.5,versicolor\n5.5,2.3,4.0,1.3,versicolor\n6.5,2.8,4.6,1.5,versicolor\n5.7,2.8,4.5,1.3,versicolor\n6.3,3.3,4.7,1.6,versicolor\n4.9,2.4,3.3,1.0,versicolor\n6.6,2.9,4.6,1.3,versicolor\n5.2,2.7,3.9,1.4,versicolor\n5.0,2.0,3.5,1.0,versicolor\n5.9,3.0,4.2,1.5,versicolor\n6.0,2.2,4.0,1.0,versicolor\n6.1,2.9,4.7,1.4,versicolor\n5.6,2.9,3.6,1.3,versicolor\n6.7,3.1,4.4,1.4,versicolor\n5.6,3.0,4.5,1.5,versicolor\n5.8,2.7,4.1,1.0,versicolor\n6.2,2.2,4.5,1.5,versicolor\n5.6,2.5,3.9,1.1,versicolor\n5.9,3.2,4.8,1.8,versicolor\n6.1,2.8,4.0,1.3,versicolor\n6.3,2.5,4.9,1.5,versicolor\n6.1,2.8,4.7,1.2,versicolor\n6.4,2.9,4.3,1.3,versicolor\n6.6,3.0,4.4,1.4,versicolo
r\n6.8,2.8,4.8,1.4,versicolor\n6.7,3.0,5.0,1.7,versicolor\n6.0,2.9,4.5,1.5,versicolor\n5.7,2.6,3.5,1.0,versicolor\n5.5,2.4,3.8,1.1,versicolor\n5.5,2.4,3.7,1.0,versicolor\n5.8,2.7,3.9,1.2,versicolor\n6.0,2.7,5.1,1.6,versicolor\n5.4,3.0,4.5,1.5,versicolor\n6.0,3.4,4.5,1.6,versicolor\n6.7,3.1,4.7,1.5,versicolor\n6.3,2.3,4.4,1.3,versicolor\n5.6,3.0,4.1,1.3,versicolor\n5.5,2.5,4.0,1.3,versicolor\n5.5,2.6,4.4,1.2,versicolor\n6.1,3.0,4.6,1.4,versicolor\n5.8,2.6,4.0,1.2,versicolor\n5.0,2.3,3.3,1.0,versicolor\n5.6,2.7,4.2,1.3,versicolor\n5.7,3.0,4.2,1.2,versicolor\n5.7,2.9,4.2,1.3,versicolor\n6.2,2.9,4.3,1.3,versicolor\n5.1,2.5,3.0,1.1,versicolor\n5.7,2.8,4.1,1.3,versicolor\n6.3,3.3,6.0,2.5,virginica\n5.8,2.7,5.1,1.9,virginica\n7.1,3.0,5.9,2.1,virginica\n6.3,2.9,5.6,1.8,virginica\n6.5,3.0,5.8,2.2,virginica\n7.6,3.0,6.6,2.1,virginica\n4.9,2.5,4.5,1.7,virginica\n7.3,2.9,6.3,1.8,virginica\n6.7,2.5,5.8,1.8,virginica\n7.2,3.6,6.1,2.5,virginica\n6.5,3.2,5.1,2.0,virginica\n6.4,2.7,5.3,1.9,virginica\n6.8,3.0,5.5,2.1,virginica\n5.7,2.5,5.0,2.0,virginica\n5.8,2.8,5.1,2.4,virginica\n6.4,3.2,5.3,2.3,virginica\n6.5,3.0,5.5,1.8,virginica\n7.7,3.8,6.7,2.2,virginica\n7.7,2.6,6.9,2.3,virginica\n6.0,2.2,5.0,1.5,virginica\n6.9,3.2,5.7,2.3,virginica\n5.6,2.8,4.9,2.0,virginica\n7.7,2.8,6.7,2.0,virginica\n6.3,2.7,4.9,1.8,virginica\n6.7,3.3,5.7,2.1,virginica\n7.2,3.2,6.0,1.8,virginica\n6.2,2.8,4.8,1.8,virginica\n6.1,3.0,4.9,1.8,virginica\n6.4,2.8,5.6,2.1,virginica\n7.2,3.0,5.8,1.6,virginica\n7.4,2.8,6.1,1.9,virginica\n7.9,3.8,6.4,2.0,virginica\n6.4,2.8,5.6,2.2,virginica\n6.3,2.8,5.1,1.5,virginica\n6.1,2.6,5.6,1.4,virginica\n7.7,3.0,6.1,2.3,virginica\n6.3,3.4,5.6,2.4,virginica\n6.4,3.1,5.5,1.8,virginica\n6.0,3.0,4.8,1.8,virginica\n6.9,3.1,5.4,2.1,virginica\n6.7,3.1,5.6,2.4,virginica\n6.9,3.1,5.1,2.3,virginica\n5.8,2.7,5.1,1.9,virginica\n6.8,3.2,5.9,2.3,virginica\n6.7,3.3,5.7,2.5,virginica\n6.7,3.0,5.2,2.3,virginica\n6.3,2.5,5.0,1.9,virginica\n6.5,3.0,5.2,2.0,virginica\n6.2,3.4,5.4,2.3,virginica\n5.
9,3.0,5.1,1.8,virginica\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"maturin>=1.3,<1.4\"]\nbuild-backend = \"maturin\"\n\n[project]\nname = \"dask_sql\"\ndescription = \"SQL query layer for Dask\"\nmaintainers = [{name = \"Nils Braun\", email = \"nilslennartbraun@gmail.com\"}]\nlicense = {text = \"MIT\"}\nclassifiers = [\n    \"Development Status :: 5 - Production/Stable\",\n    \"Intended Audience :: Developers\",\n    \"Intended Audience :: Science/Research\",\n    \"License :: OSI Approved :: MIT License\",\n    \"Operating System :: OS Independent\",\n    \"Programming Language :: Rust\",\n    \"Programming Language :: Python\",\n    \"Programming Language :: Python :: 3\",\n    \"Programming Language :: Python :: 3 :: Only\",\n    \"Programming Language :: Python :: 3.9\",\n    \"Programming Language :: Python :: 3.10\",\n    \"Programming Language :: Python :: 3.11\",\n    \"Programming Language :: Python :: 3.12\",\n    \"Topic :: Scientific/Engineering\",\n    \"Topic :: System :: Distributed Computing\",\n]\nreadme = \"README.md\"\nrequires-python = \">=3.9\"\ndependencies = [\n    \"dask[dataframe]>=2024.4.1\",\n    \"distributed>=2024.4.1\",\n    \"pandas>=1.4.0\",\n    \"fastapi>=0.92.0\",\n    \"httpx>=0.24.1\",\n    \"uvicorn>=0.14\",\n    \"tzlocal>=2.1\",\n    \"prompt_toolkit>=3.0.8\",\n    \"pygments>=2.7.1\",\n    \"tabulate\",\n]\ndynamic = [\"version\"]\n\n[project.urls]\nHomepage = \"https://github.com/dask-contrib/dask-sql\"\nDocumentation = \"https://dask-sql.readthedocs.io\"\nSource = \"https://github.com/dask-contrib/dask-sql\"\n\n[project.optional-dependencies]\ndev = [\n    \"pytest>=6.0.1\",\n    \"pytest-cov>=2.10.1\",\n    \"mock>=4.0.3\",\n    \"sphinx>=3.2.1\",\n    \"pyarrow>=14.0.1\",\n    \"scikit-learn>=1.0.0\",\n    \"intake>=0.6.0\",\n    \"pre-commit\",\n    \"black==22.10.0\",\n    \"isort==5.12.0\",\n]\nfugue = [\n    \"fugue>=0.7.3\",\n    # FIXME: https://github.com/fugue-project/fugue/issues/526\n    
\"triad<0.9.2\",\n]\n\n[project.entry-points.\"fugue.plugins\"]\ndasksql = \"dask_sql.integrations.fugue:_register_engines[fugue]\"\n\n[project.scripts]\ndask-sql = \"dask_sql.cmd:main\"\ndask-sql-server = \"dask_sql.server.app:main\"\n\n[tool.setuptools]\ninclude-package-data = true\nzip-safe = false\nlicense-files = [\"LICENSE.txt\"]\n\n[tool.setuptools.packages]\nfind = {namespaces = false}\n\n[tool.maturin]\nmodule-name = \"dask_sql._datafusion_lib\"\ninclude = [\n    { path = \"Cargo.lock\", format = \"sdist\" }\n]\nexclude = [\".github/**\", \"continuous_integration/**\"]\nlocked = true\n\n[tool.isort]\nprofile = \"black\"\n\n[tool.pytest.ini_options]\nmarkers = [\n    \"gpu: marks tests that require GPUs (skipped by default, run with --rungpu)\",\n    \"queries: marks tests that run test queries (skipped by default, run with --runqueries)\",\n]\naddopts = \"-v -rsxfE --color=yes --cov dask_sql --cov-config=.coveragerc --cov-report=term-missing\"\nfilterwarnings = [\n    \"error:::dask_sql[.*]\",\n    \"error:::dask[.*]\",\n    \"ignore:Need to do a cross-join:ResourceWarning:dask_sql[.*]\",\n    \"ignore:Dask doesn't support Dask frames:ResourceWarning:dask_sql[.*]\",\n    \"ignore:Running on a single-machine scheduler:UserWarning:dask[.*]\",\n    \"ignore:Merging dataframes with merge column data type mismatches:UserWarning:dask[.*]\",\n]\nxfail_strict = true\n"
  },
  {
    "path": "rustfmt.toml",
    "content": "imports_layout = \"HorizontalVertical\"\nimports_granularity = \"Crate\"\ngroup_imports = \"StdExternalCrate\"\n"
  },
  {
    "path": "setup.cfg",
    "content": "[flake8]\n# References:\n# https://flake8.readthedocs.io/en/latest/user/configuration.html\n# https://flake8.readthedocs.io/en/latest/user/error-codes.html\n# https://pycodestyle.pycqa.org/en/latest/intro.html#error-codes\nexclude = __init__.py\nignore =\n    E203,  # whitespace before ':'\n    E231,E241,  # Multiple spaces around \",\"\n    E731,  # Assigning lambda expression\n    #E741,  # Ambiguous variable names\n    W503,  # line break before binary operator\n    W504,  # line break after binary operator\n    ; F821,  # undefined name\nper-file-ignores =\n    tests/*:\n        # local variable is assigned to but never used\n        F841,\n        # Ambiguous variable name\n        E741,\nmax-line-length = 150\n"
  },
  {
    "path": "src/dialect.rs",
    "content": "use core::{iter::Peekable, str::Chars};\n\nuse datafusion_python::datafusion_sql::sqlparser::{\n    ast::{Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Value},\n    dialect::Dialect,\n    keywords::Keyword,\n    parser::{Parser, ParserError},\n    tokenizer::Token,\n};\n\n#[derive(Debug)]\npub struct DaskDialect {}\n\nimpl Dialect for DaskDialect {\n    fn is_identifier_start(&self, ch: char) -> bool {\n        // See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS\n        // We don't yet support identifiers beginning with \"letters with\n        // diacritical marks and non-Latin letters\"\n        ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_'\n    }\n\n    fn is_identifier_part(&self, ch: char) -> bool {\n        ch.is_ascii_lowercase()\n            || ch.is_ascii_uppercase()\n            || ch.is_ascii_digit()\n            || ch == '$'\n            || ch == '_'\n    }\n\n    /// Determine if a character starts a quoted identifier. The default\n    /// implementation, accepting \"double quoted\" ids is both ANSI-compliant\n    /// and appropriate for most dialects (with the notable exception of\n    /// MySQL, MS SQL, and sqlite). You can accept one of characters listed\n    /// in `Word::matching_end_quote` here\n    fn is_delimited_identifier_start(&self, ch: char) -> bool {\n        ch == '\"'\n    }\n    /// Determine if quoted characters are proper for identifier\n    fn is_proper_identifier_inside_quotes(&self, mut _chars: Peekable<Chars<'_>>) -> bool {\n        true\n    }\n    /// Determine if FILTER (WHERE ...) 
filters are allowed during aggregations\n    fn supports_filter_during_aggregation(&self) -> bool {\n        true\n    }\n\n    /// override expression parsing\n    fn parse_prefix(&self, parser: &mut Parser) -> Option<Result<Expr, ParserError>> {\n        fn parse_expr(parser: &mut Parser) -> Result<Option<Expr>, ParserError> {\n            match parser.peek_token().token {\n                Token::Word(w) if w.value.to_lowercase() == \"ceil\" => {\n                    // CEIL(d TO DAY)\n                    parser.next_token(); // skip ceil\n                    parser.expect_token(&Token::LParen)?;\n                    let expr = parser.parse_expr()?;\n                    if !parser.parse_keyword(Keyword::TO) {\n                        // Parse CEIL(expr) as normal\n                        parser.prev_token();\n                        parser.prev_token();\n                        parser.prev_token();\n                        return Ok(None);\n                    }\n                    let time_unit = parser.next_token();\n                    parser.expect_token(&Token::RParen)?;\n\n                    // convert to function args\n                    let args = vec![\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(\n                            Value::SingleQuotedString(time_unit.to_string()),\n                        ))),\n                    ];\n\n                    Ok(Some(Expr::Function(Function {\n                        name: ObjectName(vec![Ident::new(\"timestampceil\")]),\n                        args,\n                        over: None,\n                        distinct: false,\n                        special: false,\n                        order_by: vec![],\n                    })))\n                }\n                Token::Word(w) if w.value.to_lowercase() == \"floor\" => {\n                    // FLOOR(d TO DAY)\n                    
parser.next_token(); // skip floor\n                    parser.expect_token(&Token::LParen)?;\n                    let expr = parser.parse_expr()?;\n                    if !parser.parse_keyword(Keyword::TO) {\n                        // Parse FLOOR(expr) as normal\n                        parser.prev_token();\n                        parser.prev_token();\n                        parser.prev_token();\n                        return Ok(None);\n                    }\n                    let time_unit = parser.next_token();\n                    parser.expect_token(&Token::RParen)?;\n\n                    // convert to function args\n                    let args = vec![\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(\n                            Value::SingleQuotedString(time_unit.to_string()),\n                        ))),\n                    ];\n\n                    Ok(Some(Expr::Function(Function {\n                        name: ObjectName(vec![Ident::new(\"timestampfloor\")]),\n                        args,\n                        over: None,\n                        distinct: false,\n                        special: false,\n                        order_by: vec![],\n                    })))\n                }\n                Token::Word(w) if w.value.to_lowercase() == \"timestampadd\" => {\n                    // TIMESTAMPADD(YEAR, 2, d)\n                    parser.next_token(); // skip timestampadd\n                    parser.expect_token(&Token::LParen)?;\n                    let time_unit = parser.next_token();\n                    parser.expect_token(&Token::Comma)?;\n                    let n = parser.parse_expr()?;\n                    parser.expect_token(&Token::Comma)?;\n                    let expr = parser.parse_expr()?;\n                    parser.expect_token(&Token::RParen)?;\n\n                    // convert to function args\n               
     let args = vec![\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(\n                            Value::SingleQuotedString(time_unit.to_string()),\n                        ))),\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(n)),\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),\n                    ];\n\n                    Ok(Some(Expr::Function(Function {\n                        name: ObjectName(vec![Ident::new(\"timestampadd\")]),\n                        args,\n                        over: None,\n                        distinct: false,\n                        special: false,\n                        order_by: vec![],\n                    })))\n                }\n                Token::Word(w) if w.value.to_lowercase() == \"timestampdiff\" => {\n                    parser.next_token(); // skip timestampdiff\n                    parser.expect_token(&Token::LParen)?;\n                    let time_unit = parser.next_token();\n                    parser.expect_token(&Token::Comma)?;\n                    let expr1 = parser.parse_expr()?;\n                    parser.expect_token(&Token::Comma)?;\n                    let expr2 = parser.parse_expr()?;\n                    parser.expect_token(&Token::RParen)?;\n\n                    // convert to function args\n                    let args = vec![\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(\n                            Value::SingleQuotedString(time_unit.to_string()),\n                        ))),\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr1)),\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr2)),\n                    ];\n\n                    Ok(Some(Expr::Function(Function {\n                        name: ObjectName(vec![Ident::new(\"timestampdiff\")]),\n                        args,\n                        over: None,\n                        
distinct: false,\n                        special: false,\n                        order_by: vec![],\n                    })))\n                }\n                Token::Word(w) if w.value.to_lowercase() == \"to_timestamp\" => {\n                    // TO_TIMESTAMP(d, \"%d/%m/%Y\")\n                    parser.next_token(); // skip to_timestamp\n                    parser.expect_token(&Token::LParen)?;\n                    let expr = parser.parse_expr()?;\n                    let comma = parser.consume_token(&Token::Comma);\n                    let time_format = if comma {\n                        parser.next_token().to_string()\n                    } else {\n                        \"%Y-%m-%d %H:%M:%S\".to_string()\n                    };\n                    parser.expect_token(&Token::RParen)?;\n\n                    // convert to function args\n                    let args = vec![\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(\n                            Value::SingleQuotedString(time_format),\n                        ))),\n                    ];\n\n                    Ok(Some(Expr::Function(Function {\n                        name: ObjectName(vec![Ident::new(\"dsql_totimestamp\")]),\n                        args,\n                        over: None,\n                        distinct: false,\n                        special: false,\n                        order_by: vec![],\n                    })))\n                }\n                Token::Word(w) if w.value.to_lowercase() == \"extract\" => {\n                    // EXTRACT(DATE FROM d)\n                    parser.next_token(); // skip extract\n                    parser.expect_token(&Token::LParen)?;\n                    if !parser.parse_keywords(&[Keyword::DATE, Keyword::FROM]) {\n                        // Parse EXTRACT(x FROM d) as normal\n                        parser.prev_token();\n             
           parser.prev_token();\n                        return Ok(None);\n                    }\n                    let expr = parser.parse_expr()?;\n                    parser.expect_token(&Token::RParen)?;\n\n                    // convert to function args\n                    let args = vec![\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(\n                            Value::SingleQuotedString(\"DATE\".to_string()),\n                        ))),\n                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),\n                    ];\n\n                    Ok(Some(Expr::Function(Function {\n                        name: ObjectName(vec![Ident::new(\"extract_date\")]),\n                        args,\n                        over: None,\n                        distinct: false,\n                        special: false,\n                        order_by: vec![],\n                    })))\n                }\n                _ => Ok(None),\n            }\n        }\n        match parse_expr(parser) {\n            Ok(Some(expr)) => Some(Ok(expr)),\n            Ok(None) => None,\n            Err(e) => Some(Err(e)),\n        }\n    }\n}\n"
  },
  {
    "path": "src/error.rs",
    "content": "use std::fmt::{Display, Formatter};\n\nuse datafusion_python::{\n    datafusion_common::DataFusionError,\n    datafusion_sql::sqlparser::{parser::ParserError, tokenizer::TokenizerError},\n};\nuse pyo3::PyErr;\n\npub type Result<T> = std::result::Result<T, DaskPlannerError>;\n\n#[derive(Debug)]\npub enum DaskPlannerError {\n    DataFusionError(DataFusionError),\n    ParserError(ParserError),\n    TokenizerError(TokenizerError),\n    Internal(String),\n    InvalidIOFilter(String),\n}\n\nimpl Display for DaskPlannerError {\n    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {\n        match self {\n            Self::DataFusionError(e) => write!(f, \"DataFusion Error: {e}\"),\n            Self::ParserError(e) => write!(f, \"SQL Parser Error: {e}\"),\n            Self::TokenizerError(e) => write!(f, \"SQL Tokenizer Error: {e}\"),\n            Self::Internal(e) => write!(f, \"Internal Error: {e}\"),\n            Self::InvalidIOFilter(e) => write!(f, \"Invalid pyarrow filter: {e} encountered. Defaulting to Dask CPU/GPU bound task operation\"),\n        }\n    }\n}\n\nimpl From<TokenizerError> for DaskPlannerError {\n    fn from(err: TokenizerError) -> Self {\n        Self::TokenizerError(err)\n    }\n}\n\nimpl From<ParserError> for DaskPlannerError {\n    fn from(err: ParserError) -> Self {\n        Self::ParserError(err)\n    }\n}\n\nimpl From<DataFusionError> for DaskPlannerError {\n    fn from(err: DataFusionError) -> Self {\n        Self::DataFusionError(err)\n    }\n}\n\nimpl From<DaskPlannerError> for PyErr {\n    fn from(err: DaskPlannerError) -> PyErr {\n        PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(\"{err:?}\"))\n    }\n}\n"
  },
  {
    "path": "src/expression.rs",
    "content": "use std::{borrow::Cow, convert::From, sync::Arc};\n\nuse datafusion_python::{\n    datafusion::arrow::datatypes::DataType,\n    datafusion_common::{Column, DFField, DFSchema, ScalarValue},\n    datafusion_expr::{\n        expr::{\n            AggregateFunction,\n            AggregateUDF,\n            Alias,\n            BinaryExpr,\n            Cast,\n            Exists,\n            InList,\n            InSubquery,\n            ScalarFunction,\n            ScalarUDF,\n            Sort,\n            TryCast,\n            WindowFunction,\n        },\n        lit,\n        utils::exprlist_to_fields,\n        Between,\n        BuiltinScalarFunction,\n        Case,\n        Expr,\n        GetIndexedField,\n        Like,\n        LogicalPlan,\n        Operator,\n    },\n    datafusion_sql::TableReference,\n};\nuse pyo3::prelude::*;\n\nuse crate::{\n    error::{DaskPlannerError, Result},\n    sql::{\n        exceptions::{py_runtime_err, py_type_err},\n        logical,\n        types::RexType,\n    },\n};\n\n/// An PyExpr that can be used on a DataFrame\n#[pyclass(name = \"Expression\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone)]\npub struct PyExpr {\n    pub expr: Expr,\n    // Why a Vec here? 
Because BinaryExpr on Join might have multiple LogicalPlans\n    pub input_plan: Option<Vec<Arc<LogicalPlan>>>,\n}\n\nimpl From<PyExpr> for Expr {\n    fn from(expr: PyExpr) -> Expr {\n        expr.expr\n    }\n}\n\n#[pyclass(name = \"ScalarValue\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone)]\npub struct PyScalarValue {\n    pub scalar_value: ScalarValue,\n}\n\nimpl From<PyScalarValue> for ScalarValue {\n    fn from(pyscalar: PyScalarValue) -> ScalarValue {\n        pyscalar.scalar_value\n    }\n}\n\nimpl From<ScalarValue> for PyScalarValue {\n    fn from(scalar_value: ScalarValue) -> PyScalarValue {\n        PyScalarValue { scalar_value }\n    }\n}\n\n/// Convert a list of DataFusion Expr to PyExpr\npub fn py_expr_list(input: &Arc<LogicalPlan>, expr: &[Expr]) -> PyResult<Vec<PyExpr>> {\n    Ok(expr\n        .iter()\n        .map(|e| PyExpr::from(e.clone(), Some(vec![input.clone()])))\n        .collect())\n}\n\nimpl PyExpr {\n    /// Generally we would implement the `From` trait offered by Rust.\n    /// However, in this case Expr does not contain the contextual\n    /// `LogicalPlan` instance that we need, so we provide an associated\n    /// function that takes both and creates the PyExpr.\n    pub fn from(expr: Expr, input: Option<Vec<Arc<LogicalPlan>>>) -> PyExpr {\n        PyExpr {\n            input_plan: input,\n            expr,\n        }\n    }\n\n    /// Determines the name of the `Expr` instance by examining the LogicalPlan\n    pub fn _column_name(&self, plan: &LogicalPlan) -> Result<String> {\n        let field = expr_to_field(&self.expr, plan)?;\n        Ok(field.qualified_column().flat_name())\n    }\n\n    fn _rex_type(&self, expr: &Expr) -> RexType {\n        match expr {\n            Expr::Alias(..) => RexType::Alias,\n            Expr::Column(..)\n            | Expr::QualifiedWildcard { .. }\n            | Expr::GetIndexedField { .. }\n            | Expr::Wildcard => RexType::Reference,\n            Expr::ScalarVariable(..) 
| Expr::Literal(..) => RexType::Literal,\n            Expr::BinaryExpr { .. }\n            | Expr::Not(..)\n            | Expr::IsNotNull(..)\n            | Expr::Negative(..)\n            | Expr::IsNull(..)\n            | Expr::Like { .. }\n            | Expr::SimilarTo { .. }\n            | Expr::Between { .. }\n            | Expr::Case { .. }\n            | Expr::Cast { .. }\n            | Expr::TryCast { .. }\n            | Expr::Sort { .. }\n            | Expr::ScalarFunction { .. }\n            | Expr::AggregateFunction { .. }\n            | Expr::WindowFunction { .. }\n            | Expr::AggregateUDF { .. }\n            | Expr::InList { .. }\n            | Expr::ScalarUDF { .. }\n            | Expr::Exists { .. }\n            | Expr::InSubquery { .. }\n            | Expr::GroupingSet(..)\n            | Expr::IsTrue(..)\n            | Expr::IsFalse(..)\n            | Expr::IsUnknown(_)\n            | Expr::IsNotTrue(..)\n            | Expr::IsNotFalse(..)\n            | Expr::Placeholder { .. }\n            | Expr::OuterReferenceColumn(_, _)\n            | Expr::IsNotUnknown(_) => RexType::Call,\n            Expr::ScalarSubquery(..) => RexType::ScalarSubquery,\n        }\n    }\n}\n\nmacro_rules! extract_scalar_value {\n    ($self: expr, $variant: ident) => {\n        match $self.get_scalar_value()? 
{\n            ScalarValue::$variant(value) => Ok(*value),\n            other => Err(unexpected_literal_value(other)),\n        }\n    };\n}\n\n#[pymethods]\nimpl PyExpr {\n    #[staticmethod]\n    pub fn literal(value: PyScalarValue) -> PyExpr {\n        PyExpr::from(lit(value.scalar_value), None)\n    }\n\n    /// Extracts the LogicalPlan from a Subquery, or supported Subquery sub-type, from\n    /// the expression instance\n    #[pyo3(name = \"getSubqueryLogicalPlan\")]\n    pub fn subquery_plan(&self) -> PyResult<logical::PyLogicalPlan> {\n        match &self.expr {\n            Expr::ScalarSubquery(subquery) => Ok(subquery.subquery.as_ref().clone().into()),\n            Expr::InSubquery(insubquery) => {\n                Ok(insubquery.subquery.subquery.as_ref().clone().into())\n            }\n            _ => Err(py_type_err(format!(\n                \"Attempted to extract a LogicalPlan instance from invalid Expr {:?}.\n                Only Subquery and related variants are supported for this operation.\",\n                &self.expr\n            ))),\n        }\n    }\n\n    /// Whether this Expression instance references an existing\n    /// Column in the SQL parse tree\n    #[pyo3(name = \"isInputReference\")]\n    pub fn is_input_reference(&self) -> PyResult<bool> {\n        Ok(matches!(&self.expr, Expr::Column(_col)))\n    }\n\n    #[pyo3(name = \"toString\")]\n    pub fn to_string(&self) -> PyResult<String> {\n        Ok(format!(\"{}\", &self.expr))\n    }\n\n    /// Gets the positional index of the Expr instance from the LogicalPlan DFSchema\n    #[pyo3(name = \"getIndex\")]\n    pub fn index(&self) -> PyResult<usize> {\n        let input: &Option<Vec<Arc<LogicalPlan>>> = &self.input_plan;\n        match input {\n            Some(input_plans) if !input_plans.is_empty() => {\n                let mut schema: DFSchema = (**input_plans[0].schema()).clone();\n                for plan in input_plans.iter().skip(1) {\n                    
schema.merge(plan.schema().as_ref());\n                }\n                let name = get_expr_name(&self.expr).map_err(py_runtime_err)?;\n                if name != \"*\" {\n                    schema\n                        .index_of_column(&Column::from_qualified_name(name.clone()))\n                        .or_else(|_| {\n                            // Handles cases when from_qualified_name doesn't format the Column correctly.\n                            // \"name\" will always contain the name of the column. Anything in addition to\n                            // that will be separated by a '.' and should be further referenced.\n                            match &self.expr {\n                                Expr::Column(col) => {\n                                    schema.index_of_column(col).map_err(py_runtime_err)\n                                }\n                                _ => {\n                                    let parts = name.split('.').collect::<Vec<&str>>();\n                                    let tbl_reference = match parts.len() {\n                                        // Single element means name contains just the column name so no TableReference\n                                        1 => None,\n                                        // Tablename.column_name\n                                        2 => Some(\n                                            TableReference::Bare {\n                                                table: Cow::Borrowed(parts[0]),\n                                            }\n                                            .to_owned_reference(),\n                                        ),\n                                        // Schema_name.table_name.column_name\n                                        3 => Some(\n                                            TableReference::Partial {\n                                                schema: Cow::Borrowed(parts[0]),\n                                        
        table: Cow::Borrowed(parts[1]),\n                                            }\n                                            .to_owned_reference(),\n                                        ),\n                                        // catalog_name.schema_name.table_name.column_name\n                                        4 => Some(\n                                            TableReference::Full {\n                                                catalog: Cow::Borrowed(parts[0]),\n                                                schema: Cow::Borrowed(parts[1]),\n                                                table: Cow::Borrowed(parts[2]),\n                                            }\n                                            .to_owned_reference(),\n                                        ),\n                                        _ => None,\n                                    };\n\n                                    let col = Column {\n                                        relation: tbl_reference.clone(),\n                                        name: parts[parts.len() - 1].to_string(),\n                                    };\n                                    schema.index_of_column(&col).map_err(py_runtime_err)\n                                }\n                            }\n                        })\n                } else {\n                    // Since this is wildcard any Column will do, just use first one\n                    Ok(0)\n                }\n            }\n            _ => Err(py_runtime_err(\n                \"We need a valid LogicalPlan instance to get the Expr's index in the schema\",\n            )),\n        }\n    }\n\n    /// Examine the current/\"self\" PyExpr and return its \"type\"\n    /// In this context a \"type\" is what Dask-SQL Python\n    /// RexConverter plugin instance should be invoked to handle\n    /// the Rex conversion\n    #[pyo3(name = \"getExprType\")]\n    pub fn get_expr_type(&self) -> 
PyResult<String> {\n        Ok(String::from(match &self.expr {\n            Expr::Alias(..)\n            | Expr::Column(..)\n            | Expr::Literal(..)\n            | Expr::BinaryExpr { .. }\n            | Expr::Between { .. }\n            | Expr::Cast { .. }\n            | Expr::Sort { .. }\n            | Expr::ScalarFunction { .. }\n            | Expr::AggregateFunction { .. }\n            | Expr::InList { .. }\n            | Expr::InSubquery { .. }\n            | Expr::ScalarUDF { .. }\n            | Expr::AggregateUDF { .. }\n            | Expr::Exists { .. }\n            | Expr::ScalarSubquery(..)\n            | Expr::QualifiedWildcard { .. }\n            | Expr::Not(..)\n            | Expr::OuterReferenceColumn(_, _)\n            | Expr::GroupingSet(..) => self.expr.variant_name(),\n            Expr::ScalarVariable(..)\n            | Expr::IsNotNull(..)\n            | Expr::Negative(..)\n            | Expr::GetIndexedField { .. }\n            | Expr::IsNull(..)\n            | Expr::IsTrue(_)\n            | Expr::IsFalse(_)\n            | Expr::IsUnknown(_)\n            | Expr::IsNotTrue(_)\n            | Expr::IsNotFalse(_)\n            | Expr::Like { .. }\n            | Expr::SimilarTo { .. }\n            | Expr::IsNotUnknown(_)\n            | Expr::Case { .. }\n            | Expr::TryCast { .. }\n            | Expr::WindowFunction { .. }\n            | Expr::Placeholder { .. 
}\n            | Expr::Wildcard => {\n                return Err(py_type_err(format!(\n                    \"Encountered unsupported expression type: {}\",\n                    &self.expr.variant_name()\n                )))\n            }\n        }))\n    }\n\n    /// Determines the type of this Expr based on its variant\n    #[pyo3(name = \"getRexType\")]\n    pub fn rex_type(&self) -> PyResult<RexType> {\n        Ok(self._rex_type(&self.expr))\n    }\n\n    /// Python-friendly shim code to get the name of a column referenced by an expression\n    pub fn column_name(&self, mut plan: logical::PyLogicalPlan) -> PyResult<String> {\n        self._column_name(&plan.current_node())\n            .map_err(py_runtime_err)\n    }\n\n    /// Row expressions, Rex(s), operate on the concept of operands. This maps to expressions that are used in\n    /// the \"call\" logic of the Dask-SQL Python codebase. Different variants of Expressions, Expr(s),\n    /// store those operands in different data structures. This function examines the Expr variant and returns\n    /// the operands to the calling logic as a Vec of PyExpr instances.\n    #[pyo3(name = \"getOperands\")]\n    pub fn get_operands(&self) -> PyResult<Vec<PyExpr>> {\n        match &self.expr {\n            // Expr variants that are themselves the operand to return\n            Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Literal(..) 
=> {\n                Ok(vec![PyExpr::from(\n                    self.expr.clone(),\n                    self.input_plan.clone(),\n                )])\n            }\n\n            // Expr(s) that house the Expr instance to return in their bounded params\n            Expr::Not(expr)\n            | Expr::IsNull(expr)\n            | Expr::IsNotNull(expr)\n            | Expr::IsTrue(expr)\n            | Expr::IsFalse(expr)\n            | Expr::IsUnknown(expr)\n            | Expr::IsNotTrue(expr)\n            | Expr::IsNotFalse(expr)\n            | Expr::IsNotUnknown(expr)\n            | Expr::Negative(expr)\n            | Expr::GetIndexedField(GetIndexedField { expr, .. })\n            | Expr::Cast(Cast { expr, .. })\n            | Expr::TryCast(TryCast { expr, .. })\n            | Expr::Sort(Sort { expr, .. })\n            | Expr::InSubquery(InSubquery { expr, .. }) => {\n                Ok(vec![PyExpr::from(*expr.clone(), self.input_plan.clone())])\n            }\n\n            // Expr variants containing a collection of Expr(s) for operands\n            Expr::AggregateFunction(AggregateFunction { args, .. })\n            | Expr::AggregateUDF(AggregateUDF { args, .. })\n            | Expr::ScalarFunction(ScalarFunction { args, .. })\n            | Expr::ScalarUDF(ScalarUDF { args, .. })\n            | Expr::WindowFunction(WindowFunction { args, .. 
}) => Ok(args\n                .iter()\n                .map(|arg| PyExpr::from(arg.clone(), self.input_plan.clone()))\n                .collect()),\n\n            // Expr(s) that require more specific processing\n            Expr::Case(Case {\n                expr,\n                when_then_expr,\n                else_expr,\n            }) => {\n                let mut operands: Vec<PyExpr> = Vec::new();\n\n                if let Some(e) = expr {\n                    for (when, then) in when_then_expr {\n                        operands.push(PyExpr::from(\n                            Expr::BinaryExpr(BinaryExpr::new(\n                                Box::new(*e.clone()),\n                                Operator::Eq,\n                                Box::new(*when.clone()),\n                            )),\n                            self.input_plan.clone(),\n                        ));\n                        operands.push(PyExpr::from(*then.clone(), self.input_plan.clone()));\n                    }\n                } else {\n                    for (when, then) in when_then_expr {\n                        operands.push(PyExpr::from(*when.clone(), self.input_plan.clone()));\n                        operands.push(PyExpr::from(*then.clone(), self.input_plan.clone()));\n                    }\n                };\n\n                if let Some(e) = else_expr {\n                    operands.push(PyExpr::from(*e.clone(), self.input_plan.clone()));\n                };\n\n                Ok(operands)\n            }\n            Expr::Alias(Alias { expr, .. }) => {\n                Ok(vec![PyExpr::from(*expr.clone(), self.input_plan.clone())])\n            }\n            Expr::InList(InList { expr, list, .. 
}) => {\n                let mut operands: Vec<PyExpr> =\n                    vec![PyExpr::from(*expr.clone(), self.input_plan.clone())];\n                for list_elem in list {\n                    operands.push(PyExpr::from(list_elem.clone(), self.input_plan.clone()));\n                }\n\n                Ok(operands)\n            }\n            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => Ok(vec![\n                PyExpr::from(*left.clone(), self.input_plan.clone()),\n                PyExpr::from(*right.clone(), self.input_plan.clone()),\n            ]),\n            Expr::Like(Like { expr, pattern, .. }) => Ok(vec![\n                PyExpr::from(*expr.clone(), self.input_plan.clone()),\n                PyExpr::from(*pattern.clone(), self.input_plan.clone()),\n            ]),\n            Expr::SimilarTo(Like { expr, pattern, .. }) => Ok(vec![\n                PyExpr::from(*expr.clone(), self.input_plan.clone()),\n                PyExpr::from(*pattern.clone(), self.input_plan.clone()),\n            ]),\n            Expr::Between(Between {\n                expr,\n                negated: _,\n                low,\n                high,\n            }) => Ok(vec![\n                PyExpr::from(*expr.clone(), self.input_plan.clone()),\n                PyExpr::from(*low.clone(), self.input_plan.clone()),\n                PyExpr::from(*high.clone(), self.input_plan.clone()),\n            ]),\n            Expr::Wildcard => Ok(vec![PyExpr::from(\n                self.expr.clone(),\n                self.input_plan.clone(),\n            )]),\n\n            // Currently un-support/implemented Expr types for Rex Call operations\n            Expr::GroupingSet(..)\n            | Expr::OuterReferenceColumn(_, _)\n            | Expr::QualifiedWildcard { .. }\n            | Expr::ScalarSubquery(..)\n            | Expr::Placeholder { .. }\n            | Expr::Exists { .. 
} => Err(py_runtime_err(format!(\n                \"Unimplemented Expr type: {}\",\n                self.expr\n            ))),\n        }\n    }\n\n    #[pyo3(name = \"getOperatorName\")]\n    pub fn get_operator_name(&self) -> PyResult<String> {\n        Ok(match &self.expr {\n            Expr::BinaryExpr(BinaryExpr {\n                left: _,\n                op,\n                right: _,\n            }) => format!(\"{op}\"),\n            Expr::ScalarFunction(ScalarFunction { fun, args: _ }) => format!(\"{fun}\"),\n            Expr::ScalarUDF(ScalarUDF { fun, .. }) => fun.name.clone(),\n            Expr::Cast { .. } => \"cast\".to_string(),\n            Expr::Between { .. } => \"between\".to_string(),\n            Expr::Case { .. } => \"case\".to_string(),\n            Expr::IsNull(..) => \"is null\".to_string(),\n            Expr::IsNotNull(..) => \"is not null\".to_string(),\n            Expr::IsTrue(_) => \"is true\".to_string(),\n            Expr::IsFalse(_) => \"is false\".to_string(),\n            Expr::IsUnknown(_) => \"is unknown\".to_string(),\n            Expr::IsNotTrue(_) => \"is not true\".to_string(),\n            Expr::IsNotFalse(_) => \"is not false\".to_string(),\n            Expr::IsNotUnknown(_) => \"is not unknown\".to_string(),\n            Expr::InList { .. } => \"in list\".to_string(),\n            Expr::InSubquery(..) => \"in subquery\".to_string(),\n            Expr::Negative(..) => \"negative\".to_string(),\n            Expr::Not(..) => \"not\".to_string(),\n            Expr::Like(Like {\n                negated,\n                case_insensitive,\n                ..\n            }) => {\n                format!(\n                    \"{}{}like\",\n                    if *negated { \"not \" } else { \"\" },\n                    if *case_insensitive { \"i\" } else { \"\" }\n                )\n            }\n            Expr::SimilarTo(Like { negated, .. 
}) => {\n                if *negated {\n                    \"not similar to\".to_string()\n                } else {\n                    \"similar to\".to_string()\n                }\n            }\n            _ => {\n                return Err(py_type_err(format!(\n                    \"Catch all triggered in get_operator_name: {:?}\",\n                    &self.expr\n                )))\n            }\n        })\n    }\n\n    /// Gets the ScalarValue represented by the Expression\n    #[pyo3(name = \"getType\")]\n    pub fn get_type(&self) -> PyResult<String> {\n        Ok(String::from(match &self.expr {\n            Expr::BinaryExpr(BinaryExpr {\n                left: _,\n                op,\n                right: _,\n            }) => match op {\n                Operator::Eq\n                | Operator::NotEq\n                | Operator::Lt\n                | Operator::LtEq\n                | Operator::Gt\n                | Operator::GtEq\n                | Operator::And\n                | Operator::Or\n                | Operator::IsDistinctFrom\n                | Operator::IsNotDistinctFrom\n                | Operator::RegexMatch\n                | Operator::RegexIMatch\n                | Operator::RegexNotMatch\n                | Operator::RegexNotIMatch => \"BOOLEAN\",\n                Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Modulo => {\n                    \"BIGINT\"\n                }\n                Operator::Divide => \"FLOAT\",\n                Operator::StringConcat => \"VARCHAR\",\n                Operator::BitwiseShiftLeft\n                | Operator::BitwiseShiftRight\n                | Operator::BitwiseXor\n                | Operator::BitwiseAnd\n                | Operator::BitwiseOr => {\n                    // the type here should be the same as the type of the left expression\n                    // but we can only compute that if we have the schema available\n                    return Err(py_type_err(\n           
             \"Bitwise operators unsupported in get_type\".to_string(),\n                    ));\n                }\n                Operator::AtArrow | Operator::ArrowAt => {\n                    todo!()\n                }\n            },\n            Expr::Literal(scalar_value) => match scalar_value {\n                ScalarValue::Boolean(_value) => \"Boolean\",\n                ScalarValue::Float32(_value) => \"Float32\",\n                ScalarValue::Float64(_value) => \"Float64\",\n                ScalarValue::Decimal128(_value, ..) => \"Decimal128\",\n                ScalarValue::Decimal256(_, _, _) => \"Decimal256\",\n                ScalarValue::Dictionary(..) => \"Dictionary\",\n                ScalarValue::Int8(_value) => \"Int8\",\n                ScalarValue::Int16(_value) => \"Int16\",\n                ScalarValue::Int32(_value) => \"Int32\",\n                ScalarValue::Int64(_value) => \"Int64\",\n                ScalarValue::UInt8(_value) => \"UInt8\",\n                ScalarValue::UInt16(_value) => \"UInt16\",\n                ScalarValue::UInt32(_value) => \"UInt32\",\n                ScalarValue::UInt64(_value) => \"UInt64\",\n                ScalarValue::Utf8(_value) => \"Utf8\",\n                ScalarValue::LargeUtf8(_value) => \"LargeUtf8\",\n                ScalarValue::Binary(_value) => \"Binary\",\n                ScalarValue::LargeBinary(_value) => \"LargeBinary\",\n                ScalarValue::Date32(_value) => \"Date32\",\n                ScalarValue::Date64(_value) => \"Date64\",\n                ScalarValue::Time32Second(_value) => \"Time32\",\n                ScalarValue::Time32Millisecond(_value) => \"Time32\",\n                ScalarValue::Time64Microsecond(_value) => \"Time64\",\n                ScalarValue::Time64Nanosecond(_value) => \"Time64\",\n                ScalarValue::Null => \"Null\",\n                ScalarValue::TimestampSecond(..) => \"TimestampSecond\",\n                ScalarValue::TimestampMillisecond(..) 
=> \"TimestampMillisecond\",\n                ScalarValue::TimestampMicrosecond(..) => \"TimestampMicrosecond\",\n                ScalarValue::TimestampNanosecond(..) => \"TimestampNanosecond\",\n                ScalarValue::IntervalYearMonth(..) => \"IntervalYearMonth\",\n                ScalarValue::IntervalDayTime(..) => \"IntervalDayTime\",\n                ScalarValue::IntervalMonthDayNano(..) => \"IntervalMonthDayNano\",\n                ScalarValue::List(..) => \"List\",\n                ScalarValue::Struct(..) => \"Struct\",\n                ScalarValue::FixedSizeBinary(_, _) => \"FixedSizeBinary\",\n                ScalarValue::Fixedsizelist(..) => \"Fixedsizelist\",\n                ScalarValue::DurationSecond(..) => \"DurationSecond\",\n                ScalarValue::DurationMillisecond(..) => \"DurationMillisecond\",\n                ScalarValue::DurationMicrosecond(..) => \"DurationMicrosecond\",\n                ScalarValue::DurationNanosecond(..) => \"DurationNanosecond\",\n            },\n            Expr::ScalarFunction(ScalarFunction { fun, args: _ }) => match fun {\n                BuiltinScalarFunction::Abs => \"Abs\",\n                BuiltinScalarFunction::DatePart => \"DatePart\",\n                _ => {\n                    return Err(py_type_err(format!(\n                        \"Catch all triggered for ScalarFunction in get_type; {fun:?}\"\n                    )))\n                }\n            },\n            Expr::Cast(Cast { expr: _, data_type }) => match data_type {\n                DataType::Null => \"NULL\",\n                DataType::Boolean => \"BOOLEAN\",\n                DataType::Int8 | DataType::UInt8 => \"TINYINT\",\n                DataType::Int16 | DataType::UInt16 => \"SMALLINT\",\n                DataType::Int32 | DataType::UInt32 => \"INTEGER\",\n                DataType::Int64 | DataType::UInt64 => \"BIGINT\",\n                DataType::Float32 => \"FLOAT\",\n                DataType::Float64 => \"DOUBLE\",\n             
   DataType::Timestamp { .. } => \"TIMESTAMP\",\n                DataType::Date32 | DataType::Date64 => \"DATE\",\n                DataType::Time32(..) => \"TIME32\",\n                DataType::Time64(..) => \"TIME64\",\n                DataType::Duration(..) => \"DURATION\",\n                DataType::Interval(..) => \"INTERVAL\",\n                DataType::Binary => \"BINARY\",\n                DataType::FixedSizeBinary(..) => \"FIXEDSIZEBINARY\",\n                DataType::LargeBinary => \"LARGEBINARY\",\n                DataType::Utf8 => \"VARCHAR\",\n                DataType::LargeUtf8 => \"BIGVARCHAR\",\n                DataType::List(..) => \"LIST\",\n                DataType::FixedSizeList(..) => \"FIXEDSIZELIST\",\n                DataType::LargeList(..) => \"LARGELIST\",\n                DataType::Struct(..) => \"STRUCT\",\n                DataType::Union(..) => \"UNION\",\n                DataType::Dictionary(..) => \"DICTIONARY\",\n                DataType::Decimal128(..) => \"DECIMAL\",\n                DataType::Decimal256(..) => \"DECIMAL\",\n                DataType::Map(..) 
=> \"MAP\",\n                _ => {\n                    return Err(py_type_err(format!(\n                        \"Catch all triggered for Cast in get_type; {data_type:?}\"\n                    )))\n                }\n            },\n            _ => {\n                return Err(py_type_err(format!(\n                    \"Catch all triggered in get_type; {:?}\",\n                    &self.expr\n                )))\n            }\n        }))\n    }\n\n    /// Gets the precision/scale represented by the Expression's decimal datatype\n    #[pyo3(name = \"getPrecisionScale\")]\n    pub fn get_precision_scale(&self) -> PyResult<(u8, i8)> {\n        Ok(match &self.expr {\n            Expr::Cast(Cast { expr: _, data_type }) => match data_type {\n                DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => {\n                    (*precision, *scale)\n                }\n                _ => {\n                    return Err(py_type_err(format!(\n                        \"Catch all triggered for Cast in get_precision_scale; {data_type:?}\"\n                    )))\n                }\n            },\n            _ => {\n                return Err(py_type_err(format!(\n                    \"Catch all triggered in get_precision_scale; {:?}\",\n                    &self.expr\n                )))\n            }\n        })\n    }\n\n    #[pyo3(name = \"getFilterExpr\")]\n    pub fn get_filter_expr(&self) -> PyResult<Option<PyExpr>> {\n        // TODO refactor to avoid duplication\n        match &self.expr {\n            Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {\n                Expr::AggregateFunction(AggregateFunction { filter, .. })\n                | Expr::AggregateUDF(AggregateUDF { filter, .. 
}) => match filter {\n                    Some(filter) => {\n                        Ok(Some(PyExpr::from(*filter.clone(), self.input_plan.clone())))\n                    }\n                    None => Ok(None),\n                },\n                _ => Err(py_type_err(\n                    \"getFilterExpr() - Non-aggregate expression encountered\",\n                )),\n            },\n            Expr::AggregateFunction(AggregateFunction { filter, .. })\n            | Expr::AggregateUDF(AggregateUDF { filter, .. }) => match filter {\n                Some(filter) => Ok(Some(PyExpr::from(*filter.clone(), self.input_plan.clone()))),\n                None => Ok(None),\n            },\n            _ => Err(py_type_err(\n                \"getFilterExpr() - Non-aggregate expression encountered\",\n            )),\n        }\n    }\n\n    #[pyo3(name = \"getFloat32Value\")]\n    pub fn float_32_value(&self) -> PyResult<Option<f32>> {\n        extract_scalar_value!(self, Float32)\n    }\n\n    #[pyo3(name = \"getFloat64Value\")]\n    pub fn float_64_value(&self) -> PyResult<Option<f64>> {\n        extract_scalar_value!(self, Float64)\n    }\n\n    #[pyo3(name = \"getDecimal128Value\")]\n    pub fn decimal_128_value(&mut self) -> PyResult<(Option<i128>, u8, i8)> {\n        match self.get_scalar_value()? 
{\n            ScalarValue::Decimal128(value, precision, scale) => Ok((*value, *precision, *scale)),\n            other => Err(unexpected_literal_value(other)),\n        }\n    }\n\n    #[pyo3(name = \"getInt8Value\")]\n    pub fn int_8_value(&self) -> PyResult<Option<i8>> {\n        extract_scalar_value!(self, Int8)\n    }\n\n    #[pyo3(name = \"getInt16Value\")]\n    pub fn int_16_value(&self) -> PyResult<Option<i16>> {\n        extract_scalar_value!(self, Int16)\n    }\n\n    #[pyo3(name = \"getInt32Value\")]\n    pub fn int_32_value(&self) -> PyResult<Option<i32>> {\n        extract_scalar_value!(self, Int32)\n    }\n\n    #[pyo3(name = \"getInt64Value\")]\n    pub fn int_64_value(&self) -> PyResult<Option<i64>> {\n        extract_scalar_value!(self, Int64)\n    }\n\n    #[pyo3(name = \"getUInt8Value\")]\n    pub fn uint_8_value(&self) -> PyResult<Option<u8>> {\n        extract_scalar_value!(self, UInt8)\n    }\n\n    #[pyo3(name = \"getUInt16Value\")]\n    pub fn uint_16_value(&self) -> PyResult<Option<u16>> {\n        extract_scalar_value!(self, UInt16)\n    }\n\n    #[pyo3(name = \"getUInt32Value\")]\n    pub fn uint_32_value(&self) -> PyResult<Option<u32>> {\n        extract_scalar_value!(self, UInt32)\n    }\n\n    #[pyo3(name = \"getUInt64Value\")]\n    pub fn uint_64_value(&self) -> PyResult<Option<u64>> {\n        extract_scalar_value!(self, UInt64)\n    }\n\n    #[pyo3(name = \"getDate32Value\")]\n    pub fn date_32_value(&self) -> PyResult<Option<i32>> {\n        extract_scalar_value!(self, Date32)\n    }\n\n    #[pyo3(name = \"getDate64Value\")]\n    pub fn date_64_value(&self) -> PyResult<Option<i64>> {\n        extract_scalar_value!(self, Date64)\n    }\n\n    #[pyo3(name = \"getTime64Value\")]\n    pub fn time_64_value(&self) -> PyResult<Option<i64>> {\n        extract_scalar_value!(self, Time64Nanosecond)\n    }\n\n    #[pyo3(name = \"getTimestampValue\")]\n    pub fn timestamp_value(&mut self) -> PyResult<(Option<i64>, Option<String>)> {\n       
 match self.get_scalar_value()? {\n            ScalarValue::TimestampNanosecond(iv, tz)\n            | ScalarValue::TimestampMicrosecond(iv, tz)\n            | ScalarValue::TimestampMillisecond(iv, tz)\n            | ScalarValue::TimestampSecond(iv, tz) => match tz {\n                Some(time_zone) => Ok((*iv, Some(time_zone.to_string()))),\n                None => Ok((*iv, None)),\n            },\n            other => Err(unexpected_literal_value(other)),\n        }\n    }\n\n    #[pyo3(name = \"getBoolValue\")]\n    pub fn bool_value(&self) -> PyResult<Option<bool>> {\n        extract_scalar_value!(self, Boolean)\n    }\n\n    #[pyo3(name = \"getStringValue\")]\n    pub fn string_value(&self) -> PyResult<Option<String>> {\n        match self.get_scalar_value()? {\n            ScalarValue::Utf8(value) => Ok(value.clone()),\n            other => Err(unexpected_literal_value(other)),\n        }\n    }\n\n    #[pyo3(name = \"getIntervalDayTimeValue\")]\n    pub fn interval_day_time_value(&self) -> PyResult<Option<(i32, i32)>> {\n        match self.get_scalar_value()? {\n            ScalarValue::IntervalDayTime(Some(iv)) => {\n                let interval = *iv as u64;\n                let days = (interval >> 32) as i32;\n                let ms = interval as i32;\n                Ok(Some((days, ms)))\n            }\n            ScalarValue::IntervalDayTime(None) => Ok(None),\n            other => Err(unexpected_literal_value(other)),\n        }\n    }\n\n    #[pyo3(name = \"getIntervalMonthDayNanoValue\")]\n    pub fn interval_month_day_nano_value(&self) -> PyResult<Option<(i32, i32, i64)>> {\n        match self.get_scalar_value()? 
{\n            ScalarValue::IntervalMonthDayNano(Some(iv)) => {\n                // Assuming Arrow's packed i128 layout: months in bits 96..128,\n                // days in bits 64..96, nanoseconds in bits 0..64\n                let interval = *iv as u128;\n                let months = (interval >> 96) as i32;\n                let days = (interval >> 64) as i32;\n                let ns = interval as i64;\n                Ok(Some((months, days, ns)))\n            }\n            ScalarValue::IntervalMonthDayNano(None) => Ok(None),\n            other => Err(unexpected_literal_value(other)),\n        }\n    }\n\n    #[pyo3(name = \"isNegated\")]\n    pub fn is_negated(&self) -> PyResult<bool> {\n        match &self.expr {\n            Expr::Between(Between { negated, .. })\n            | Expr::Exists(Exists { negated, .. })\n            | Expr::InList(InList { negated, .. })\n            | Expr::InSubquery(InSubquery { negated, .. }) => Ok(*negated),\n            _ => Err(py_type_err(format!(\n                \"unknown Expr type {:?} encountered\",\n                &self.expr\n            ))),\n        }\n    }\n\n    #[pyo3(name = \"isDistinctAgg\")]\n    pub fn is_distinct_aggregation(&self) -> PyResult<bool> {\n        // TODO refactor to avoid duplication\n        match &self.expr {\n            Expr::AggregateFunction(funct) => Ok(funct.distinct),\n            Expr::AggregateUDF { .. } => Ok(false),\n            Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {\n                Expr::AggregateFunction(funct) => Ok(funct.distinct),\n                Expr::AggregateUDF { .. 
} => Ok(false),\n                _ => Err(py_type_err(\n                    \"isDistinctAgg() - Non-aggregate expression encountered\",\n                )),\n            },\n            _ => Err(py_type_err(\n                \"isDistinctAgg() - Non-aggregate expression encountered\",\n            )),\n        }\n    }\n\n    /// Returns whether a sort expression is an ascending sort\n    #[pyo3(name = \"isSortAscending\")]\n    pub fn is_sort_ascending(&self) -> PyResult<bool> {\n        match &self.expr {\n            Expr::Sort(Sort { asc, .. }) => Ok(*asc),\n            _ => Err(py_type_err(format!(\n                \"Provided Expr {:?} is not a sort type\",\n                &self.expr\n            ))),\n        }\n    }\n\n    /// Returns whether nulls should be placed first in a sort expression\n    #[pyo3(name = \"isSortNullsFirst\")]\n    pub fn is_sort_nulls_first(&self) -> PyResult<bool> {\n        match &self.expr {\n            Expr::Sort(Sort { nulls_first, .. }) => Ok(*nulls_first),\n            _ => Err(py_type_err(format!(\n                \"Provided Expr {:?} is not a sort type\",\n                &self.expr\n            ))),\n        }\n    }\n\n    /// Returns the escape char for like/ilike/similar to expr variants\n    #[pyo3(name = \"getEscapeChar\")]\n    pub fn get_escape_char(&self) -> PyResult<Option<char>> {\n        match &self.expr {\n            Expr::Like(Like { escape_char, .. }) | Expr::SimilarTo(Like { escape_char, .. 
}) => {\n                Ok(*escape_char)\n            }\n            _ => Err(py_type_err(format!(\n                \"Provided Expr {:?} not one of Like/ILike/SimilarTo\",\n                &self.expr\n            ))),\n        }\n    }\n}\n\nimpl PyExpr {\n    /// Get the scalar value represented by this literal expression, returning an error\n    /// if this is not a literal expression\n    fn get_scalar_value(&self) -> Result<&ScalarValue> {\n        match &self.expr {\n            Expr::Literal(v) => Ok(v),\n            _ => Err(DaskPlannerError::Internal(\n                \"get_scalar_value() called on non-literal expression\".to_string(),\n            )),\n        }\n    }\n}\n\nfn unexpected_literal_value(value: &ScalarValue) -> PyErr {\n    DaskPlannerError::Internal(format!(\"getValue<T>() - Unexpected value: {value}\")).into()\n}\n\nfn get_expr_name(expr: &Expr) -> Result<String> {\n    match expr {\n        Expr::Alias(Alias { expr, .. }) => get_expr_name(expr),\n        Expr::Wildcard => {\n            // 'Wildcard' means any and all columns, so return '*' and let the caller\n            // resolve it to a concrete column\n            Ok(\"*\".to_owned())\n        }\n        _ => Ok(expr.canonical_name()),\n    }\n}\n\n/// Create a [DFField] representing an [Expr], given an input [LogicalPlan] to resolve against\npub fn expr_to_field(expr: &Expr, input_plan: &LogicalPlan) -> Result<DFField> {\n    match expr {\n        Expr::Sort(Sort { expr, .. }) => {\n            // DataFusion does not support create_name for sort expressions (since they never\n            // appear in projections) so we just delegate to the contained expression instead\n            expr_to_field(expr, input_plan)\n        }\n        Expr::Wildcard => {\n            // Any column will do. 
We use the first column to keep things consistent\n            Ok(input_plan.schema().field(0).clone())\n        }\n        Expr::InSubquery(insubquery) => expr_to_field(&insubquery.expr, input_plan),\n        _ => {\n            let fields =\n                exprlist_to_fields(&[expr.clone()], input_plan).map_err(DaskPlannerError::from)?;\n            Ok(fields[0].clone())\n        }\n    }\n}\n\n#[cfg(test)]\nmod test {\n    use datafusion_python::{\n        datafusion_common::{Column, ScalarValue},\n        datafusion_expr::Expr,\n    };\n\n    use crate::{error::Result, expression::PyExpr};\n\n    #[test]\n    fn get_value_u32() -> Result<()> {\n        test_get_value(ScalarValue::UInt32(None))?;\n        test_get_value(ScalarValue::UInt32(Some(123)))\n    }\n\n    #[test]\n    fn get_value_utf8() -> Result<()> {\n        test_get_value(ScalarValue::Utf8(None))?;\n        test_get_value(ScalarValue::Utf8(Some(\"hello\".to_string())))\n    }\n\n    #[test]\n    fn get_value_non_literal() -> Result<()> {\n        let expr = PyExpr::from(Expr::Column(Column::from_qualified_name(\"a.b\")), None);\n        let error = expr\n            .get_scalar_value()\n            .expect_err(\"cannot get scalar value from column\");\n        assert_eq!(\n            \"Internal(\\\"get_scalar_value() called on non-literal expression\\\")\",\n            &format!(\"{:?}\", error)\n        );\n        Ok(())\n    }\n\n    fn test_get_value(value: ScalarValue) -> Result<()> {\n        let expr = PyExpr::from(Expr::Literal(value.clone()), None);\n        assert_eq!(&value, expr.get_scalar_value()?);\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "src/lib.rs",
    "content": "use log::debug;\nuse pyo3::prelude::*;\n\nmod dialect;\nmod error;\nmod expression;\nmod parser;\nmod sql;\n\n/// Low-level DataFusion internal package.\n///\n/// The higher-level public API is defined in pure python files under the\n/// dask_planner directory.\n#[pymodule]\nfn _datafusion_lib(py: Python, m: &PyModule) -> PyResult<()> {\n    // Initialize the global Python logger instance\n    pyo3_log::init();\n\n    // Register the python classes\n    m.add_class::<expression::PyExpr>()?;\n    m.add_class::<sql::DaskSQLContext>()?;\n    m.add_class::<sql::types::SqlTypeName>()?;\n    m.add_class::<sql::types::RexType>()?;\n    m.add_class::<sql::types::DaskTypeMap>()?;\n    m.add_class::<sql::types::rel_data_type::RelDataType>()?;\n    m.add_class::<sql::statement::PyStatement>()?;\n    m.add_class::<sql::schema::DaskSchema>()?;\n    m.add_class::<sql::table::DaskTable>()?;\n    m.add_class::<sql::function::DaskFunction>()?;\n    m.add_class::<sql::table::DaskStatistics>()?;\n    m.add_class::<sql::logical::PyLogicalPlan>()?;\n    m.add_class::<sql::DaskSQLOptimizerConfig>()?;\n\n    // Exceptions\n    m.add(\n        \"DFParsingException\",\n        py.get_type::<sql::exceptions::ParsingException>(),\n    )?;\n    m.add(\n        \"DFOptimizationException\",\n        py.get_type::<sql::exceptions::OptimizationException>(),\n    )?;\n\n    debug!(\"dask_sql native library loaded\");\n\n    Ok(())\n}\n"
  },
  {
    "path": "src/parser.rs",
    "content": "//! SQL Parser\n//!\n//! Declares a SQL parser based on sqlparser that handles custom formats that we need.\n\nuse std::collections::VecDeque;\n\nuse datafusion_python::datafusion_sql::sqlparser::{\n    ast::{Expr, Ident, SelectItem, Statement as SQLStatement, UnaryOperator, Value},\n    dialect::{keywords::Keyword, Dialect},\n    parser::{Parser, ParserError},\n    tokenizer::{Token, TokenWithLocation, Tokenizer},\n};\nuse pyo3::prelude::*;\n\nuse crate::{\n    dialect::DaskDialect,\n    sql::{exceptions::py_type_err, parser_utils::DaskParserUtils, types::SqlTypeName},\n};\n\nmacro_rules! parser_err {\n    ($MSG:expr) => {\n        Err(ParserError::ParserError($MSG.to_string()))\n    };\n}\n\n#[derive(Debug, Clone, PartialEq, Eq)]\npub enum CustomExpr {\n    Map(Vec<Expr>),\n    Multiset(Vec<Expr>),\n    Nested(Vec<(String, PySqlArg)>),\n}\n\n#[pyclass(name = \"SqlArg\", module = \"dask_sql\")]\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct PySqlArg {\n    expr: Option<Expr>,\n    custom: Option<CustomExpr>,\n}\n\nimpl PySqlArg {\n    pub fn new(expr: Option<Expr>, custom: Option<CustomExpr>) -> Self {\n        Self { expr, custom }\n    }\n\n    fn expected<T>(&self, expected: &str) -> PyResult<T> {\n        Err(match &self.custom {\n            Some(custom_expr) => {\n                py_type_err(format!(\"Expected {expected}, found: {custom_expr:?}\"))\n            }\n            None => match &self.expr {\n                Some(expr) => py_type_err(format!(\"Expected {expected}, found: {expr:?}\")),\n                None => py_type_err(\"PySqlArg must be either a standard or custom AST expression\"),\n            },\n        })\n    }\n}\n\n#[pymethods]\nimpl PySqlArg {\n    #[pyo3(name = \"isCollection\")]\n    pub fn is_collection(&self) -> PyResult<bool> {\n        Ok(match &self.custom {\n            Some(custom_expr) => !matches!(custom_expr, CustomExpr::Nested(_)),\n            None => match &self.expr {\n                Some(expr) => 
matches!(expr, Expr::Array(_)),\n                None => return self.expected(\"\"),\n            },\n        })\n    }\n\n    #[pyo3(name = \"isKwargs\")]\n    pub fn is_kwargs(&self) -> PyResult<bool> {\n        Ok(matches!(&self.custom, Some(CustomExpr::Nested(_))))\n    }\n\n    #[pyo3(name = \"getOperandList\")]\n    pub fn get_operand_list(&self) -> PyResult<Vec<PySqlArg>> {\n        Ok(match &self.custom {\n            Some(custom_expr) => match custom_expr {\n                CustomExpr::Map(exprs) | CustomExpr::Multiset(exprs) => exprs\n                    .iter()\n                    .map(|e| PySqlArg::new(Some(e.clone()), None))\n                    .collect(),\n                _ => vec![],\n            },\n            None => match &self.expr {\n                Some(expr) => match expr {\n                    Expr::Array(array) => array\n                        .elem\n                        .iter()\n                        .map(|e| PySqlArg::new(Some(e.clone()), None))\n                        .collect(),\n                    _ => vec![],\n                },\n                None => return self.expected(\"\"),\n            },\n        })\n    }\n\n    #[pyo3(name = \"getKwargs\")]\n    pub fn get_kwargs(&self) -> PyResult<Vec<(String, PySqlArg)>> {\n        Ok(match &self.custom {\n            Some(CustomExpr::Nested(kwargs)) => kwargs.clone(),\n            _ => vec![],\n        })\n    }\n\n    #[pyo3(name = \"getSqlType\")]\n    pub fn get_sql_type(&self) -> PyResult<SqlTypeName> {\n        Ok(match &self.custom {\n            Some(custom_expr) => match custom_expr {\n                CustomExpr::Map(_) => SqlTypeName::MAP,\n                CustomExpr::Multiset(_) => SqlTypeName::MULTISET,\n                _ => return self.expected(\"Map or multiset\"),\n            },\n            None => match &self.expr {\n                Some(Expr::Array(_)) => SqlTypeName::ARRAY,\n                Some(Expr::Identifier(Ident { .. 
})) => SqlTypeName::VARCHAR,\n                Some(Expr::Value(scalar)) => match scalar {\n                    Value::Boolean(_) => SqlTypeName::BOOLEAN,\n                    Value::Number(_, false) => SqlTypeName::BIGINT,\n                    Value::SingleQuotedString(_) => SqlTypeName::VARCHAR,\n                    _ => return self.expected(\"Boolean, integer, float, or single-quoted string\"),\n                },\n                Some(Expr::UnaryOp {\n                    op: UnaryOperator::Minus,\n                    expr,\n                }) => match &**expr {\n                    Expr::Value(Value::Number(_, false)) => SqlTypeName::BIGINT,\n                    _ => return self.expected(\"Integer or float\"),\n                },\n                Some(_) => return self.expected(\"Array, identifier, or scalar\"),\n                None => return self.expected(\"\"),\n            },\n        })\n    }\n\n    #[pyo3(name = \"getSqlValue\")]\n    pub fn get_sql_value(&self) -> PyResult<String> {\n        Ok(match &self.custom {\n            None => match &self.expr {\n                Some(Expr::Identifier(Ident { value, .. 
})) => value.to_string(),\n                Some(Expr::Value(scalar)) => match scalar {\n                    Value::Boolean(true) => \"1\".to_string(),\n                    Value::Boolean(false) => \"\".to_string(),\n                    Value::SingleQuotedString(string) => string.to_string(),\n                    Value::Number(value, false) => value.to_string(),\n                    _ => return self.expected(\"Boolean, integer, float, or single-quoted string\"),\n                },\n                Some(Expr::UnaryOp {\n                    op: UnaryOperator::Minus,\n                    expr,\n                }) => match &**expr {\n                    Expr::Value(Value::Number(value, false)) => format!(\"-{value}\"),\n                    _ => return self.expected(\"Integer or float\"),\n                },\n                _ => return self.expected(\"Array, identifier, or scalar\"),\n            },\n            _ => return self.expected(\"Standard sqlparser AST expression\"),\n        })\n    }\n}\n\n/// Dask-SQL extension DDL for `CREATE MODEL`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct CreateModel {\n    /// schema and model name, i.e. 'schema_name.model_name'\n    pub schema_name: Option<String>,\n    pub model_name: String,\n    /// input query\n    pub select: DaskStatement,\n    /// whether or not IF NOT EXISTS was specified\n    pub if_not_exists: bool,\n    /// whether or not OR REPLACE was specified\n    pub or_replace: bool,\n    /// kwargs specified in WITH\n    pub with_options: Vec<(String, PySqlArg)>,\n}\n\n/// Dask-SQL extension DDL for `CREATE EXPERIMENT`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct CreateExperiment {\n    /// schema and experiment name, i.e. 
'schema_name.experiment_name'\n    pub schema_name: Option<String>,\n    pub experiment_name: String,\n    /// input query\n    pub select: DaskStatement,\n    /// whether or not IF NOT EXISTS was specified\n    pub if_not_exists: bool,\n    /// whether or not OR REPLACE was specified\n    pub or_replace: bool,\n    /// kwargs specified in WITH\n    pub with_options: Vec<(String, PySqlArg)>,\n}\n\n/// Dask-SQL extension DDL for `PREDICT`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct PredictModel {\n    /// schema and model name, i.e. 'schema_name.model_name'\n    pub schema_name: Option<String>,\n    pub model_name: String,\n    /// input query\n    pub select: DaskStatement,\n}\n\n/// Dask-SQL extension DDL for `CREATE SCHEMA`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct CreateCatalogSchema {\n    /// schema name\n    pub schema_name: String,\n    /// whether or not IF NOT EXISTS was specified\n    pub if_not_exists: bool,\n    /// whether or not OR REPLACE was specified\n    pub or_replace: bool,\n}\n\n/// Dask-SQL extension DDL for `CREATE TABLE ... WITH`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct CreateTable {\n    /// schema and table name, i.e. 'schema_name.table_name'\n    pub schema_name: Option<String>,\n    pub table_name: String,\n    /// whether or not IF NOT EXISTS was specified\n    pub if_not_exists: bool,\n    /// whether or not OR REPLACE was specified\n    pub or_replace: bool,\n    /// kwargs specified in WITH\n    pub with_options: Vec<(String, PySqlArg)>,\n}\n\n/// Dask-SQL extension DDL for `DROP MODEL`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct DropModel {\n    /// schema and model name, i.e. 'schema_name.model_name'\n    pub schema_name: Option<String>,\n    pub model_name: String,\n    /// whether or not IF EXISTS was specified\n    pub if_exists: bool,\n}\n\n/// Dask-SQL extension DDL for `EXPORT MODEL`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct ExportModel {\n    /// schema and model name, i.e. 
'schema_name.model_name'\n    pub schema_name: Option<String>,\n    pub model_name: String,\n    /// kwargs specified in WITH\n    pub with_options: Vec<(String, PySqlArg)>,\n}\n\n/// Dask-SQL extension DDL for `DESCRIBE MODEL`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct DescribeModel {\n    /// schema and model name, i.e. 'schema_name.model_name'\n    pub schema_name: Option<String>,\n    pub model_name: String,\n}\n\n/// Dask-SQL extension DDL for `SHOW SCHEMAS`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct ShowSchemas {\n    /// optional catalog name\n    pub catalog_name: Option<String>,\n    /// optional LIKE identifier\n    pub like: Option<String>,\n}\n\n/// Dask-SQL extension DDL for `SHOW TABLES FROM`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct ShowTables {\n    /// catalog and schema name, i.e. 'catalog_name.schema_name'\n    pub catalog_name: Option<String>,\n    pub schema_name: Option<String>,\n}\n\n/// Dask-SQL extension DDL for `SHOW COLUMNS FROM`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct ShowColumns {\n    /// schema and table name, i.e. 'schema_name.table_name'\n    pub schema_name: Option<String>,\n    pub table_name: String,\n}\n\n/// Dask-SQL extension DDL for `SHOW MODELS`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct ShowModels {\n    pub schema_name: Option<String>,\n}\n\n/// Dask-SQL extension DDL for `DROP SCHEMA`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct DropSchema {\n    /// schema name\n    pub schema_name: String,\n    /// whether or not IF EXISTS was specified\n    pub if_exists: bool,\n}\n\n/// Dask-SQL extension DDL for `USE SCHEMA`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct UseSchema {\n    /// schema name\n    pub schema_name: String,\n}\n\n/// Dask-SQL extension DDL for `ANALYZE TABLE`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct AnalyzeTable {\n    /// schema and table name, i.e. 
'schema_name.table_name'\n    pub schema_name: Option<String>,\n    pub table_name: String,\n    /// columns to analyze in specified table\n    pub columns: Vec<String>,\n}\n\n/// Dask-SQL extension DDL for `ALTER TABLE`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct AlterTable {\n    pub old_table_name: String,\n    pub new_table_name: String,\n    pub schema_name: Option<String>,\n    pub if_exists: bool,\n}\n\n/// Dask-SQL extension DDL for `ALTER SCHEMA`\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct AlterSchema {\n    pub old_schema_name: String,\n    pub new_schema_name: String,\n}\n\n/// Dask-SQL Statement representations.\n///\n/// Tokens parsed by `DaskParser` are converted into these values.\n#[derive(Debug, Clone, PartialEq, Eq)]\npub enum DaskStatement {\n    /// ANSI SQL AST node\n    Statement(Box<SQLStatement>),\n    /// Extension: `CREATE MODEL`\n    CreateModel(Box<CreateModel>),\n    /// Extension: `CREATE EXPERIMENT`\n    CreateExperiment(Box<CreateExperiment>),\n    /// Extension: `CREATE SCHEMA`\n    CreateCatalogSchema(Box<CreateCatalogSchema>),\n    /// Extension: `CREATE TABLE`\n    CreateTable(Box<CreateTable>),\n    /// Extension: `DROP MODEL`\n    DropModel(Box<DropModel>),\n    /// Extension: `EXPORT MODEL`\n    ExportModel(Box<ExportModel>),\n    /// Extension: `DESCRIBE MODEL`\n    DescribeModel(Box<DescribeModel>),\n    /// Extension: `PREDICT`\n    PredictModel(Box<PredictModel>),\n    // Extension: `SHOW SCHEMAS`\n    ShowSchemas(Box<ShowSchemas>),\n    // Extension: `SHOW TABLES FROM`\n    ShowTables(Box<ShowTables>),\n    // Extension: `SHOW COLUMNS FROM`\n    ShowColumns(Box<ShowColumns>),\n    // Extension: `SHOW MODELS`\n    ShowModels(Box<ShowModels>),\n    // Extension: `DROP SCHEMA`\n    DropSchema(Box<DropSchema>),\n    // Extension: `USE SCHEMA`\n    UseSchema(Box<UseSchema>),\n    // Extension: `ANALYZE TABLE`\n    AnalyzeTable(Box<AnalyzeTable>),\n    // Extension: `ALTER TABLE`\n    
AlterTable(Box<AlterTable>),\n    // Extension: `ALTER SCHEMA`\n    AlterSchema(Box<AlterSchema>),\n}\n\n/// SQL Parser\npub struct DaskParser<'a> {\n    parser: Parser<'a>,\n}\n\nimpl<'a> DaskParser<'a> {\n    #[allow(dead_code)]\n    /// Create a parser for the given SQL string using the Dask dialect\n    pub fn new(sql: &str) -> Result<Self, ParserError> {\n        let dialect = &DaskDialect {};\n        DaskParser::new_with_dialect(sql, dialect)\n    }\n\n    /// Tokenize the given SQL string and create a parser using the specified dialect\n    pub fn new_with_dialect(sql: &str, dialect: &'a dyn Dialect) -> Result<Self, ParserError> {\n        let mut tokenizer = Tokenizer::new(dialect, sql);\n        let tokens = tokenizer.tokenize()?;\n\n        Ok(DaskParser {\n            parser: Parser::new(dialect).with_tokens(tokens),\n        })\n    }\n\n    #[allow(dead_code)]\n    /// Parse a SQL string with the Dask dialect and produce a queue of statements\n    pub fn parse_sql(sql: &str) -> Result<VecDeque<DaskStatement>, ParserError> {\n        let dialect = &DaskDialect {};\n        DaskParser::parse_sql_with_dialect(sql, dialect)\n    }\n\n    /// Parse a SQL string with the specified dialect and produce a queue of statements\n    pub fn parse_sql_with_dialect(\n        sql: &str,\n        dialect: &dyn Dialect,\n    ) -> Result<VecDeque<DaskStatement>, ParserError> {\n        let mut parser = DaskParser::new_with_dialect(sql, dialect)?;\n        let mut stmts = VecDeque::new();\n        let mut expecting_statement_delimiter = false;\n        loop {\n            // ignore empty statements (between successive statement delimiters)\n            while parser.parser.consume_token(&Token::SemiColon) {\n                expecting_statement_delimiter = false;\n            }\n\n            if parser.parser.peek_token() == Token::EOF {\n                break;\n            }\n            if expecting_statement_delimiter {\n                return parser.expected(\"end of statement\", parser.parser.peek_token());\n            }\n\n            let statement = 
parser.parse_statement()?;\n            stmts.push_back(statement);\n            expecting_statement_delimiter = true;\n        }\n        Ok(stmts)\n    }\n\n    /// Report unexpected token\n    fn expected<T>(&self, expected: &str, found: TokenWithLocation) -> Result<T, ParserError> {\n        parser_err!(format!(\n            \"Expected {}, found: {} at line {} column {}\",\n            expected, found.token, found.location.line, found.location.column\n        ))\n    }\n\n    /// Parse a new statement\n    pub fn parse_statement(&mut self) -> Result<DaskStatement, ParserError> {\n        match self.parser.peek_token().token {\n            Token::Word(w) => {\n                match w.keyword {\n                    Keyword::CREATE => {\n                        // move one token forward\n                        self.parser.next_token();\n                        // use custom parsing\n                        self.parse_create()\n                    }\n                    Keyword::DROP => {\n                        // move one token forward\n                        self.parser.next_token();\n                        // use custom parsing\n                        self.parse_drop()\n                    }\n                    Keyword::SELECT => {\n                        // Check for PREDICT token in statement\n                        let mut cnt = 1;\n                        loop {\n                            match self.parser.next_token().token {\n                                Token::Word(w) => {\n                                    match w.value.to_lowercase().as_str() {\n                                        \"predict\" => {\n                                            return self.parse_predict_model();\n                                        }\n                                        _ => {\n                                            // Keep looking for PREDICT\n                                            cnt += 1;\n                                          
  continue;\n                                        }\n                                    }\n                                }\n                                Token::EOF => {\n                                    break;\n                                }\n                                _ => {\n                                    // Keep looking for PREDICT\n                                    cnt += 1;\n                                    continue;\n                                }\n                            }\n                        }\n\n                        // Reset the parser back to where we started\n                        for _ in 0..cnt {\n                            self.parser.prev_token();\n                        }\n\n                        // use the native parser\n                        Ok(DaskStatement::Statement(Box::from(\n                            self.parser.parse_statement()?,\n                        )))\n                    }\n                    Keyword::SHOW => {\n                        // move one token forward\n                        self.parser.next_token();\n                        // use custom parsing\n                        self.parse_show()\n                    }\n                    Keyword::DESCRIBE => {\n                        // move one token forward\n                        self.parser.next_token();\n                        // use custom parsing\n                        self.parse_describe()\n                    }\n                    Keyword::USE => {\n                        // move one token forward\n                        self.parser.next_token();\n                        // use custom parsing\n                        self.parse_use()\n                    }\n                    Keyword::ANALYZE => {\n                        // move one token forward\n                        self.parser.next_token();\n                        self.parse_analyze()\n                    }\n                    Keyword::ALTER => {\n 
                       // move one token forward\n                        self.parser.next_token();\n                        self.parse_alter()\n                    }\n                    _ => {\n                        match w.value.to_lowercase().as_str() {\n                            \"export\" => {\n                                // move one token forward\n                                self.parser.next_token();\n                                // use custom parsing\n                                self.parse_export_model()\n                            }\n                            _ => {\n                                // use the native parser\n                                Ok(DaskStatement::Statement(Box::from(\n                                    self.parser.parse_statement()?,\n                                )))\n                            }\n                        }\n                    }\n                }\n            }\n            _ => {\n                // use the native parser\n                Ok(DaskStatement::Statement(Box::from(\n                    self.parser.parse_statement()?,\n                )))\n            }\n        }\n    }\n\n    /// Parse a SQL CREATE statement\n    pub fn parse_create(&mut self) -> Result<DaskStatement, ParserError> {\n        let or_replace = self.parser.parse_keywords(&[Keyword::OR, Keyword::REPLACE]);\n        match self.parser.peek_token().token {\n            Token::Word(w) => {\n                match w.value.to_lowercase().as_str() {\n                    \"model\" => {\n                        // move one token forward\n                        self.parser.next_token();\n\n                        let if_not_exists = self.parser.parse_keywords(&[\n                            Keyword::IF,\n                            Keyword::NOT,\n                            Keyword::EXISTS,\n                        ]);\n\n                        // use custom parsing\n                        
self.parse_create_model(if_not_exists, or_replace)\n                    }\n                    \"experiment\" => {\n                        // move one token forward\n                        self.parser.next_token();\n\n                        let if_not_exists = self.parser.parse_keywords(&[\n                            Keyword::IF,\n                            Keyword::NOT,\n                            Keyword::EXISTS,\n                        ]);\n\n                        // use custom parsing\n                        self.parse_create_experiment(if_not_exists, or_replace)\n                    }\n                    \"schema\" => {\n                        // move one token forward\n                        self.parser.next_token();\n\n                        let if_not_exists = self.parser.parse_keywords(&[\n                            Keyword::IF,\n                            Keyword::NOT,\n                            Keyword::EXISTS,\n                        ]);\n\n                        // use custom parsing\n                        self.parse_create_schema(if_not_exists, or_replace)\n                    }\n                    \"table\" => {\n                        // move one token forward\n                        self.parser.next_token();\n\n                        // use custom parsing\n                        self.parse_create_table(true, or_replace)\n                    }\n                    \"view\" => {\n                        // move one token forward\n                        self.parser.next_token();\n                        // use custom parsing\n                        self.parse_create_table(false, or_replace)\n                    }\n                    _ => {\n                        if or_replace {\n                            // Go back two tokens if OR REPLACE was consumed\n                            self.parser.prev_token();\n                            self.parser.prev_token();\n                        }\n                        // use 
the native parser\n                        Ok(DaskStatement::Statement(Box::from(\n                            self.parser.parse_create()?,\n                        )))\n                    }\n                }\n            }\n            _ => {\n                if or_replace {\n                    // Go back two tokens if OR REPLACE was consumed\n                    self.parser.prev_token();\n                    self.parser.prev_token();\n                }\n                // use the native parser\n                Ok(DaskStatement::Statement(Box::from(\n                    self.parser.parse_create()?,\n                )))\n            }\n        }\n    }\n\n    /// Parse a SQL DROP statement\n    pub fn parse_drop(&mut self) -> Result<DaskStatement, ParserError> {\n        match self.parser.peek_token().token {\n            Token::Word(w) => {\n                match w.value.to_lowercase().as_str() {\n                    \"model\" => {\n                        // move one token forward\n                        self.parser.next_token();\n                        // use custom parsing\n                        self.parse_drop_model()\n                    }\n                    \"schema\" => {\n                        // move one token forward\n                        self.parser.next_token();\n                        // use custom parsing\n\n                        let if_exists = self.parser.parse_keywords(&[Keyword::IF, Keyword::EXISTS]);\n\n                        let schema_name = self.parser.parse_identifier()?;\n\n                        let drop_schema = DropSchema {\n                            schema_name: schema_name.value,\n                            if_exists,\n                        };\n                        Ok(DaskStatement::DropSchema(Box::new(drop_schema)))\n                    }\n                    _ => {\n                        // use the native parser\n                        Ok(DaskStatement::Statement(Box::from(\n                            
self.parser.parse_drop()?,\n                        )))\n                    }\n                }\n            }\n            _ => {\n                // use the native parser\n                Ok(DaskStatement::Statement(Box::from(\n                    self.parser.parse_drop()?,\n                )))\n            }\n        }\n    }\n\n    /// Parse a SQL SHOW statement\n    pub fn parse_show(&mut self) -> Result<DaskStatement, ParserError> {\n        match self.parser.peek_token().token {\n            Token::Word(w) => {\n                match w.value.to_lowercase().as_str() {\n                    \"schemas\" => {\n                        // move one token forward\n                        self.parser.next_token();\n                        // use custom parsing\n                        self.parse_show_schemas()\n                    }\n                    \"tables\" => {\n                        // move one token forward\n                        self.parser.next_token();\n\n                        // If non ansi ... 
`FROM {schema_name}` is present custom parse\n                        // otherwise use sqlparser-rs\n                        match self.parser.peek_token().token {\n                            Token::Word(w) => {\n                                match w.value.to_lowercase().as_str() {\n                                    \"from\" => {\n                                        // move one token forward\n                                        self.parser.next_token();\n                                        // use custom parsing\n                                        self.parse_show_tables()\n                                    }\n                                    _ => {\n                                        self.parser.prev_token();\n                                        // use the native parser\n                                        Ok(DaskStatement::Statement(Box::from(\n                                            self.parser.parse_show()?,\n                                        )))\n                                    }\n                                }\n                            }\n                            _ => self.parse_show_tables(),\n                        }\n                    }\n                    \"columns\" => {\n                        self.parser.next_token();\n                        // use custom parsing\n                        self.parse_show_columns()\n                    }\n                    \"models\" => {\n                        self.parser.next_token();\n                        // use custom parsing\n                        self.parse_show_models()\n                    }\n                    _ => {\n                        // use the native parser\n                        Ok(DaskStatement::Statement(Box::from(\n                            self.parser.parse_show()?,\n                        )))\n                    }\n                }\n            }\n            _ => {\n                // use the native parser\n       
         Ok(DaskStatement::Statement(Box::from(\n                    self.parser.parse_show()?,\n                )))\n            }\n        }\n    }\n\n    /// Parse a SQL DESCRIBE statement\n    pub fn parse_describe(&mut self) -> Result<DaskStatement, ParserError> {\n        match self.parser.peek_token().token {\n            Token::Word(w) => {\n                match w.value.to_lowercase().as_str() {\n                    \"model\" => {\n                        self.parser.next_token();\n                        // use custom parsing\n                        self.parse_describe_model()\n                    }\n                    _ => {\n                        // use the native parser\n                        Ok(DaskStatement::Statement(Box::from(\n                            self.parser.parse_show()?,\n                        )))\n                    }\n                }\n            }\n            _ => {\n                // use the native parser\n                Ok(DaskStatement::Statement(Box::from(\n                    self.parser.parse_show()?,\n                )))\n            }\n        }\n    }\n\n    /// Parse a SQL USE SCHEMA statement\n    pub fn parse_use(&mut self) -> Result<DaskStatement, ParserError> {\n        match self.parser.peek_token().token {\n            Token::Word(w) => {\n                match w.value.to_lowercase().as_str() {\n                    \"schema\" => {\n                        // move one token forward\n                        self.parser.next_token();\n                        // use custom parsing\n                        let schema_name = self.parser.parse_identifier()?;\n\n                        let use_schema = UseSchema {\n                            schema_name: schema_name.value,\n                        };\n                        Ok(DaskStatement::UseSchema(Box::new(use_schema)))\n                    }\n                    _ => Ok(DaskStatement::Statement(Box::from(\n                        
self.parser.parse_show()?,\n                    ))),\n                }\n            }\n            _ => Ok(DaskStatement::Statement(Box::from(\n                self.parser.parse_show()?,\n            ))),\n        }\n    }\n\n    /// Parse a SQL ANALYZE statement\n    pub fn parse_analyze(&mut self) -> Result<DaskStatement, ParserError> {\n        match self.parser.peek_token().token {\n            Token::Word(w) => {\n                match w.value.to_lowercase().as_str() {\n                    \"table\" => {\n                        // move one token forward\n                        self.parser.next_token();\n                        // use custom parsing\n                        self.parse_analyze_table()\n                    }\n                    _ => {\n                        // use the native parser\n                        Ok(DaskStatement::Statement(Box::from(\n                            self.parser.parse_analyze()?,\n                        )))\n                    }\n                }\n            }\n            _ => {\n                // use the native parser\n                Ok(DaskStatement::Statement(Box::from(\n                    self.parser.parse_analyze()?,\n                )))\n            }\n        }\n    }\n\n    /// Parse a SQL ALTER statement\n    pub fn parse_alter(&mut self) -> Result<DaskStatement, ParserError> {\n        match self.parser.peek_token().token {\n            Token::Word(w) => {\n                match w.keyword {\n                    Keyword::TABLE => {\n                        self.parser.next_token();\n                        self.parse_alter_table()\n                    }\n                    Keyword::SCHEMA => {\n                        self.parser.next_token();\n                        self.parse_alter_schema()\n                    }\n                    _ => {\n                        // use the native parser\n                        Ok(DaskStatement::Statement(Box::from(\n                            
self.parser.parse_alter()?,\n                        )))\n                    }\n                }\n            }\n            _ => {\n                // use the native parser\n                Ok(DaskStatement::Statement(Box::from(\n                    self.parser.parse_alter()?,\n                )))\n            }\n        }\n    }\n\n    /// Parse a SQL PREDICT statement\n    pub fn parse_predict_model(&mut self) -> Result<DaskStatement, ParserError> {\n        // PREDICT(\n        //     MODEL model_name,\n        //     SQLStatement\n        // )\n        self.parser.expect_token(&Token::LParen)?;\n\n        let is_model = match self.parser.next_token().token {\n            Token::Word(w) => matches!(w.value.to_lowercase().as_str(), \"model\"),\n            _ => false,\n        };\n        if !is_model {\n            return Err(ParserError::ParserError(\n                \"parse_predict_model: Expected `MODEL`\".to_string(),\n            ));\n        }\n\n        let (schema_name, model_name) =\n            DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?;\n        self.parser.expect_token(&Token::Comma)?;\n\n        // Limit our input to  ANALYZE, DESCRIBE, SELECT, SHOW statements\n        // TODO: find a more sophisticated way to allow any statement that would return a table\n        self.parser.expect_one_of_keywords(&[\n            Keyword::SELECT,\n            Keyword::DESCRIBE,\n            Keyword::SHOW,\n            Keyword::ANALYZE,\n        ])?;\n        self.parser.prev_token();\n\n        let select = self.parse_statement()?;\n\n        self.parser.expect_token(&Token::RParen)?;\n\n        let predict = PredictModel {\n            schema_name,\n            model_name,\n            select,\n        };\n        Ok(DaskStatement::PredictModel(Box::new(predict)))\n    }\n\n    /// Parse Dask-SQL CREATE MODEL statement\n    fn parse_create_model(\n        &mut self,\n        if_not_exists: bool,\n        or_replace: bool,\n   
 ) -> Result<DaskStatement, ParserError> {\n        // Parse schema and model name\n        let (schema_name, model_name) =\n            DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?;\n\n        // Parse WITH options\n        self.parser.expect_keyword(Keyword::WITH)?;\n        self.parser.expect_token(&Token::LParen)?;\n        let with_options = self.parse_comma_separated(DaskParser::parse_key_value_pair)?;\n        self.parser.expect_token(&Token::RParen)?;\n\n        // Parse the nested query statement\n        self.parser.expect_keyword(Keyword::AS)?;\n        self.parser.expect_token(&Token::LParen)?;\n\n        // Limit our input to  ANALYZE, DESCRIBE, SELECT, SHOW statements\n        // TODO: find a more sophisticated way to allow any statement that would return a table\n        self.parser.expect_one_of_keywords(&[\n            Keyword::SELECT,\n            Keyword::DESCRIBE,\n            Keyword::SHOW,\n            Keyword::ANALYZE,\n        ])?;\n        self.parser.prev_token();\n\n        let select = self.parse_statement()?;\n\n        self.parser.expect_token(&Token::RParen)?;\n\n        let create = CreateModel {\n            schema_name,\n            model_name,\n            select,\n            if_not_exists,\n            or_replace,\n            with_options,\n        };\n        Ok(DaskStatement::CreateModel(Box::new(create)))\n    }\n\n    // copied from sqlparser crate and adapted to work with DaskParser\n    fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>\n    where\n        F: FnMut(&mut DaskParser<'a>) -> Result<T, ParserError>,\n    {\n        let mut values = vec![];\n        loop {\n            values.push(f(self)?);\n            if !self.parser.consume_token(&Token::Comma) {\n                break;\n            }\n        }\n        Ok(values)\n    }\n\n    fn parse_key_value_pair(&mut self) -> Result<(String, PySqlArg), ParserError> {\n        let key = 
self.parser.parse_identifier()?;\n        self.parser.expect_token(&Token::Eq)?;\n        match self.parser.next_token().token {\n            Token::LParen => {\n                let key_value_pairs =\n                    self.parse_comma_separated(DaskParser::parse_key_value_pair)?;\n                self.parser.expect_token(&Token::RParen)?;\n                Ok((\n                    key.value,\n                    PySqlArg::new(None, Some(CustomExpr::Nested(key_value_pairs))),\n                ))\n            }\n            Token::Word(w) if w.value.to_lowercase().as_str() == \"map\" => {\n                // TODO this does not support map or multiset expressions within the map\n                self.parser.expect_token(&Token::LBracket)?;\n                let values = self.parser.parse_comma_separated(Parser::parse_expr)?;\n                self.parser.expect_token(&Token::RBracket)?;\n                Ok((\n                    key.value,\n                    PySqlArg::new(None, Some(CustomExpr::Map(values))),\n                ))\n            }\n            Token::Word(w) if w.value.to_lowercase().as_str() == \"multiset\" => {\n                // TODO this does not support map or multiset expressions within the multiset\n                self.parser.expect_token(&Token::LBracket)?;\n                let values = self.parser.parse_comma_separated(Parser::parse_expr)?;\n                self.parser.expect_token(&Token::RBracket)?;\n                Ok((\n                    key.value,\n                    PySqlArg::new(None, Some(CustomExpr::Multiset(values))),\n                ))\n            }\n            _ => {\n                self.parser.prev_token();\n                Ok((\n                    key.value,\n                    PySqlArg::new(Some(self.parser.parse_expr()?), None),\n                ))\n            }\n        }\n    }\n\n    /// Parse Dask-SQL CREATE EXPERIMENT statement\n    fn parse_create_experiment(\n        &mut self,\n        if_not_exists: bool,\n  
      or_replace: bool,\n    ) -> Result<DaskStatement, ParserError> {\n        // Parse schema and experiment name\n        let (schema_name, experiment_name) =\n            DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?;\n\n        // Parse WITH options\n        self.parser.expect_keyword(Keyword::WITH)?;\n        self.parser.expect_token(&Token::LParen)?;\n        let with_options = self.parse_comma_separated(DaskParser::parse_key_value_pair)?;\n        self.parser.expect_token(&Token::RParen)?;\n\n        // Parse the nested query statement\n        self.parser.expect_keyword(Keyword::AS)?;\n        self.parser.expect_token(&Token::LParen)?;\n\n        // Limit our input to ANALYZE, DESCRIBE, SELECT, SHOW statements\n        // TODO: find a more sophisticated way to allow any statement that would return a table\n        self.parser.expect_one_of_keywords(&[\n            Keyword::SELECT,\n            Keyword::DESCRIBE,\n            Keyword::SHOW,\n            Keyword::ANALYZE,\n        ])?;\n        self.parser.prev_token();\n\n        let select = self.parse_statement()?;\n\n        self.parser.expect_token(&Token::RParen)?;\n\n        let create = CreateExperiment {\n            schema_name,\n            experiment_name,\n            select,\n            if_not_exists,\n            or_replace,\n            with_options,\n        };\n        Ok(DaskStatement::CreateExperiment(Box::new(create)))\n    }\n\n    /// Parse Dask-SQL CREATE {IF NOT EXISTS | OR REPLACE} SCHEMA ... 
statement\n    fn parse_create_schema(\n        &mut self,\n        if_not_exists: bool,\n        or_replace: bool,\n    ) -> Result<DaskStatement, ParserError> {\n        let schema_name = self.parser.parse_identifier()?.value;\n\n        let create = CreateCatalogSchema {\n            schema_name,\n            if_not_exists,\n            or_replace,\n        };\n        Ok(DaskStatement::CreateCatalogSchema(Box::new(create)))\n    }\n\n    /// Parse Dask-SQL CREATE [OR REPLACE] TABLE ... statement\n    ///\n    /// # Arguments\n    ///\n    /// * `is_table` - Whether the \"table\" is a \"TABLE\" or \"VIEW\", True if \"TABLE\" and False otherwise.\n    /// * `or_replace` - True if the \"TABLE\" or \"VIEW\" should be replaced and False otherwise\n    fn parse_create_table(\n        &mut self,\n        is_table: bool,\n        or_replace: bool,\n    ) -> Result<DaskStatement, ParserError> {\n        // parse [IF NOT EXISTS] `table_name` AS|WITH\n        let if_not_exists =\n            self.parser\n                .parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]);\n\n        let _table_name = self.parser.parse_identifier();\n        let after_name_token = self.parser.peek_token().token;\n\n        match after_name_token {\n            Token::Word(w) => {\n                match w.value.to_lowercase().as_str() {\n                    \"as\" => {\n                        self.parser.prev_token();\n                        if if_not_exists {\n                            // Go back three tokens if IF NOT EXISTS was consumed, native parser consumes these tokens as well\n                            self.parser.prev_token();\n                            self.parser.prev_token();\n                            self.parser.prev_token();\n                        }\n\n                        // True if TABLE and False if VIEW\n                        if is_table {\n                            Ok(DaskStatement::Statement(Box::from(\n                                
self.parser\n                                    .parse_create_table(or_replace, false, None, false)?,\n                            )))\n                        } else {\n                            self.parser.prev_token();\n                            Ok(DaskStatement::Statement(Box::from(\n                                self.parser.parse_create_view(or_replace)?,\n                            )))\n                        }\n                    }\n                    \"with\" => {\n                        // `table_name` has been parsed at this point but is needed, reset consumption\n                        self.parser.prev_token();\n\n                        // Parse schema and table name\n                        let (schema_name, table_name) = DaskParserUtils::elements_from_object_name(\n                            &self.parser.parse_object_name()?,\n                        )?;\n\n                        // Parse WITH options\n                        self.parser.expect_keyword(Keyword::WITH)?;\n                        self.parser.expect_token(&Token::LParen)?;\n                        let with_options =\n                            self.parse_comma_separated(DaskParser::parse_key_value_pair)?;\n                        self.parser.expect_token(&Token::RParen)?;\n\n                        let create = CreateTable {\n                            schema_name,\n                            table_name,\n                            if_not_exists,\n                            or_replace,\n                            with_options,\n                        };\n                        Ok(DaskStatement::CreateTable(Box::new(create)))\n                    }\n                    _ => self.expected(\"'as' or 'with'\", self.parser.peek_token()),\n                }\n            }\n            _ => {\n                self.parser.prev_token();\n                if if_not_exists {\n                    // Go back three tokens if IF NOT EXISTS was consumed\n                    
self.parser.prev_token();\n                    self.parser.prev_token();\n                    self.parser.prev_token();\n                }\n                // use the native parser\n                Ok(DaskStatement::Statement(Box::from(\n                    self.parser\n                        .parse_create_table(or_replace, false, None, false)?,\n                )))\n            }\n        }\n    }\n\n    /// Parse Dask-SQL EXPORT MODEL statement\n    fn parse_export_model(&mut self) -> Result<DaskStatement, ParserError> {\n        let is_model = match self.parser.next_token().token {\n            Token::Word(w) => matches!(w.value.to_lowercase().as_str(), \"model\"),\n            _ => false,\n        };\n        if !is_model {\n            return Err(ParserError::ParserError(\n                \"parse_export_model: Expected `MODEL`\".to_string(),\n            ));\n        }\n\n        // Parse schema and model name\n        let (schema_name, model_name) =\n            DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?;\n\n        // Parse WITH options\n        self.parser.expect_keyword(Keyword::WITH)?;\n        self.parser.expect_token(&Token::LParen)?;\n        let with_options = self.parse_comma_separated(DaskParser::parse_key_value_pair)?;\n        self.parser.expect_token(&Token::RParen)?;\n\n        let export = ExportModel {\n            schema_name,\n            model_name,\n            with_options,\n        };\n        Ok(DaskStatement::ExportModel(Box::new(export)))\n    }\n\n    /// Parse Dask-SQL DROP MODEL statement\n    fn parse_drop_model(&mut self) -> Result<DaskStatement, ParserError> {\n        let if_exists = self.parser.parse_keywords(&[Keyword::IF, Keyword::EXISTS]);\n        // Parse schema and model name\n        let (schema_name, model_name) =\n            DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?;\n\n        let drop = DropModel {\n            schema_name,\n            
model_name,\n            if_exists,\n        };\n        Ok(DaskStatement::DropModel(Box::new(drop)))\n    }\n\n    /// Parse Dask-SQL DESCRIBE MODEL statement\n    fn parse_describe_model(&mut self) -> Result<DaskStatement, ParserError> {\n        // Parse schema and model name\n        let (schema_name, model_name) =\n            DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?;\n\n        let describe = DescribeModel {\n            schema_name,\n            model_name,\n        };\n        Ok(DaskStatement::DescribeModel(Box::new(describe)))\n    }\n\n    /// Parse Dask-SQL SHOW SCHEMAS statement\n    fn parse_show_schemas(&mut self) -> Result<DaskStatement, ParserError> {\n        // parse optional `FROM` clause\n        let catalog_name = match self.parser.peek_token().token {\n            Token::Word(w) => {\n                match w.keyword {\n                    Keyword::FROM => {\n                        // move one token forward\n                        self.parser.next_token();\n                        // use custom parsing\n                        Some(self.parser.parse_identifier()?.value)\n                    }\n                    _ => None,\n                }\n            }\n            _ => None,\n        };\n        // parse optional `LIKE` clause\n        let like = match self.parser.peek_token().token {\n            Token::Word(w) => {\n                match w.keyword {\n                    Keyword::LIKE => {\n                        // move one token forward\n                        self.parser.next_token();\n                        // use custom parsing\n                        Some(self.parser.parse_identifier()?.value)\n                    }\n                    _ => None,\n                }\n            }\n            _ => None,\n        };\n\n        Ok(DaskStatement::ShowSchemas(Box::new(ShowSchemas {\n            catalog_name,\n            like,\n        })))\n    }\n\n    /// Parse Dask-SQL SHOW TABLES [FROM] 
statement\n    fn parse_show_tables(&mut self) -> Result<DaskStatement, ParserError> {\n        if let Ok(obj_name) = &self.parser.parse_object_name() {\n            let (catalog_name, schema_name) = DaskParserUtils::elements_from_object_name(obj_name)?;\n            return Ok(DaskStatement::ShowTables(Box::new(ShowTables {\n                catalog_name,\n                schema_name: Some(schema_name),\n            })));\n        }\n        Ok(DaskStatement::ShowTables(Box::new(ShowTables {\n            catalog_name: None,\n            schema_name: None,\n        })))\n    }\n\n    /// Parse Dask-SQL SHOW COLUMNS FROM <table>\n    fn parse_show_columns(&mut self) -> Result<DaskStatement, ParserError> {\n        self.parser.expect_keyword(Keyword::FROM)?;\n        let (schema_name, table_name) =\n            DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?;\n        Ok(DaskStatement::ShowColumns(Box::new(ShowColumns {\n            schema_name,\n            table_name,\n        })))\n    }\n\n    /// Parse Dask-SQL SHOW MODELS [FROM <schema>]\n    fn parse_show_models(&mut self) -> Result<DaskStatement, ParserError> {\n        let mut schema_name: Option<String> = None;\n        if !self.parser.consume_token(&Token::EOF) {\n            self.parser.expect_keyword(Keyword::FROM)?;\n            schema_name = Some(self.parser.parse_identifier()?.value);\n        }\n        Ok(DaskStatement::ShowModels(Box::new(ShowModels {\n            schema_name,\n        })))\n    }\n\n    /// Parse Dask-SQL ANALYZE TABLE <table>\n    fn parse_analyze_table(&mut self) -> Result<DaskStatement, ParserError> {\n        let obj_name = self.parser.parse_object_name()?;\n        self.parser\n            .expect_keywords(&[Keyword::COMPUTE, Keyword::STATISTICS, Keyword::FOR])?;\n        let (schema_name, table_name) = DaskParserUtils::elements_from_object_name(&obj_name)?;\n        let columns = match self\n            .parser\n            
.parse_keywords(&[Keyword::ALL, Keyword::COLUMNS])\n        {\n            true => vec![],\n            false => {\n                self.parser.expect_keyword(Keyword::COLUMNS)?;\n                let mut values = vec![];\n                for select in self.parser.parse_projection()? {\n                    match select {\n                        SelectItem::UnnamedExpr(expr) => match expr {\n                            Expr::Identifier(ident) => values.push(ident.value),\n                            unexpected => {\n                                return parser_err!(format!(\n                                    \"Expected Identifier, found: {unexpected}\"\n                                ))\n                            }\n                        },\n                        unexpected => {\n                            return parser_err!(format!(\n                                \"Expected UnnamedExpr, found: {unexpected}\"\n                            ))\n                        }\n                    }\n                }\n                values\n            }\n        };\n        Ok(DaskStatement::AnalyzeTable(Box::new(AnalyzeTable {\n            schema_name,\n            table_name,\n            columns,\n        })))\n    }\n\n    fn parse_alter_table(&mut self) -> Result<DaskStatement, ParserError> {\n        let if_exists = self.parser.parse_keywords(&[Keyword::IF, Keyword::EXISTS]);\n\n        // parse fully qualified old table name\n        let (schema_name, old_table_name) =\n            DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?;\n\n        self.parser\n            .expect_keywords(&[Keyword::RENAME, Keyword::TO])?;\n\n        // parse new table name\n        let new_table_name = self.parser.parse_identifier()?.value;\n\n        Ok(DaskStatement::AlterTable(Box::new(AlterTable {\n            old_table_name,\n            new_table_name,\n            schema_name,\n            if_exists,\n        })))\n    }\n\n    fn 
parse_alter_schema(&mut self) -> Result<DaskStatement, ParserError> {\n        // parse old schema name\n        let old_schema_name = self.parser.parse_identifier()?.value;\n\n        self.parser\n            .expect_keywords(&[Keyword::RENAME, Keyword::TO])?;\n\n        // parse new schema name\n        let new_schema_name = self.parser.parse_identifier()?.value;\n\n        Ok(DaskStatement::AlterSchema(Box::new(AlterSchema {\n            old_schema_name,\n            new_schema_name,\n        })))\n    }\n}\n\n#[cfg(test)]\nmod test {\n    use crate::parser::{DaskParser, DaskStatement};\n\n    #[test]\n    fn timestampadd() {\n        let sql = \"SELECT TIMESTAMPADD(YEAR, 2, d) FROM t\";\n        let statements = DaskParser::parse_sql(sql).unwrap();\n        assert_eq!(1, statements.len());\n        let actual = format!(\"{:?}\", statements[0]);\n        let expected = \"Statement(Query(Query { with: None, body: Select(Select { distinct: None, top: None, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: \\\"timestampadd\\\", quote_style: None }]), args: [Unnamed(Expr(Value(SingleQuotedString(\\\"YEAR\\\")))), Unnamed(Expr(Value(Number(\\\"2\\\", false)))), Unnamed(Expr(Identifier(Ident { value: \\\"d\\\", quote_style: None })))], over: None, distinct: false, special: false, order_by: [] }))], into: None, from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: \\\"t\\\", quote_style: None }]), alias: None, args: None, with_hints: [], version: None, partitions: [] }, joins: [] }], lateral_views: [], selection: None, group_by: Expressions([]), cluster_by: [], distribute_by: [], sort_by: [], having: None, named_window: [], qualify: None }), order_by: [], limit: None, offset: None, fetch: None, locks: [] }))\";\n        assert!(actual.contains(expected));\n    }\n\n    #[test]\n    fn to_timestamp() {\n        let sql1 = \"SELECT TO_TIMESTAMP(d) FROM t\";\n        let statements1 = DaskParser::parse_sql(sql1).unwrap();\n 
       assert_eq!(1, statements1.len());\n        let actual1 = format!(\"{:?}\", statements1[0]);\n        let expected1 = \"Statement(Query(Query { with: None, body: Select(Select { distinct: None, top: None, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: \\\"dsql_totimestamp\\\", quote_style: None }]), args: [Unnamed(Expr(Identifier(Ident { value: \\\"d\\\", quote_style: None }))), Unnamed(Expr(Value(SingleQuotedString(\\\"%Y-%m-%d %H:%M:%S\\\"))))], over: None, distinct: false, special: false, order_by: [] }))], into: None, from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: \\\"t\\\", quote_style: None }]), alias: None, args: None, with_hints: [], version: None, partitions: [] }, joins: [] }], lateral_views: [], selection: None, group_by: Expressions([]), cluster_by: [], distribute_by: [], sort_by: [], having: None, named_window: [], qualify: None }), order_by: [], limit: None, offset: None, fetch: None, locks: [] }))\";\n\n        assert!(actual1.contains(expected1));\n\n        let sql2 = \"SELECT TO_TIMESTAMP(d, \\\"%d/%m/%Y\\\") FROM t\";\n        let statements2 = DaskParser::parse_sql(sql2).unwrap();\n        assert_eq!(1, statements2.len());\n        let actual2 = format!(\"{:?}\", statements2[0]);\n        let expected2 = \"Statement(Query(Query { with: None, body: Select(Select { distinct: None, top: None, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: \\\"dsql_totimestamp\\\", quote_style: None }]), args: [Unnamed(Expr(Identifier(Ident { value: \\\"d\\\", quote_style: None }))), Unnamed(Expr(Value(SingleQuotedString(\\\"\\\\\\\"%d/%m/%Y\\\\\\\"\\\"))))], over: None, distinct: false, special: false, order_by: [] }))], into: None, from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: \\\"t\\\", quote_style: None }]), alias: None, args: None, with_hints: [], version: None, partitions: [] }, joins: [] }], lateral_views: [], selection: None, group_by: 
Expressions([]), cluster_by: [], distribute_by: [], sort_by: [], having: None, named_window: [], qualify: None }), order_by: [], limit: None, offset: None, fetch: None, locks: [] }))\";\n\n        assert!(actual2.contains(expected2));\n    }\n\n    #[test]\n    fn create_model() {\n        let sql = r#\"CREATE MODEL my_model WITH (\n            model_class = 'mock.MagicMock',\n            target_column = 'target',\n            fit_kwargs = (\n                single_quoted_string = 'hello',\n                double_quoted_string = \"hi\",\n                integer = -300,\n                float = 23.45,\n                boolean = False,\n                array = ARRAY [ 1, 2 ],\n                dict = MAP [ 'a', 1 ],\n                set = MULTISET [ 1, 1, 2, 3 ]\n            )\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\"#;\n        let statements = DaskParser::parse_sql(sql).unwrap();\n        assert_eq!(1, statements.len());\n\n        match &statements[0] {\n            DaskStatement::CreateModel(create_model) => {\n                let expected = \"[\\\n                    (\\\"model_class\\\", PySqlArg { expr: Some(Value(SingleQuotedString(\\\"mock.MagicMock\\\"))), custom: None }), \\\n                    (\\\"target_column\\\", PySqlArg { expr: Some(Value(SingleQuotedString(\\\"target\\\"))), custom: None }), \\\n                    (\\\"fit_kwargs\\\", PySqlArg { expr: None, custom: Some(Nested([\\\n                        (\\\"single_quoted_string\\\", PySqlArg { expr: Some(Value(SingleQuotedString(\\\"hello\\\"))), custom: None }), \\\n                        (\\\"double_quoted_string\\\", PySqlArg { expr: Some(Identifier(Ident { value: \\\"hi\\\", quote_style: Some('\\\"') })), custom: None }), \\\n                        (\\\"integer\\\", PySqlArg { expr: Some(UnaryOp { op: Minus, expr: Value(Number(\\\"300\\\", false)) }), custom: None }), \\\n                        
(\\\"float\\\", PySqlArg { expr: Some(Value(Number(\\\"23.45\\\", false))), custom: None }), \\\n                        (\\\"boolean\\\", PySqlArg { expr: Some(Value(Boolean(false))), custom: None }), \\\n                        (\\\"array\\\", PySqlArg { expr: Some(Array(Array { elem: [Value(Number(\\\"1\\\", false)), Value(Number(\\\"2\\\", false))], named: true })), custom: None }), \\\n                        (\\\"dict\\\", PySqlArg { expr: None, custom: Some(Map([Value(SingleQuotedString(\\\"a\\\")), Value(Number(\\\"1\\\", false))])) }), \\\n                        (\\\"set\\\", PySqlArg { expr: None, custom: Some(Multiset([Value(Number(\\\"1\\\", false)), Value(Number(\\\"1\\\", false)), Value(Number(\\\"2\\\", false)), Value(Number(\\\"3\\\", false))])) })\\\n                    ])) })\\\n                ]\";\n                assert_eq!(expected, &format!(\"{:?}\", create_model.with_options));\n            }\n            _ => panic!(),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/column.rs",
    "content": "use datafusion_python::datafusion_common::Column;\nuse pyo3::prelude::*;\n\n#[pyclass(name = \"Column\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone)]\npub struct PyColumn {\n    /// Original Column instance\n    pub(crate) column: Column,\n}\n\nimpl From<PyColumn> for Column {\n    fn from(column: PyColumn) -> Column {\n        column.column\n    }\n}\n\nimpl From<Column> for PyColumn {\n    fn from(column: Column) -> PyColumn {\n        PyColumn { column }\n    }\n}\n\n#[pymethods]\nimpl PyColumn {\n    #[pyo3(name = \"getRelation\")]\n    pub fn relation(&self) -> String {\n        self.column.relation.clone().unwrap().to_string()\n    }\n\n    #[pyo3(name = \"getName\")]\n    pub fn name(&self) -> String {\n        self.column.name.clone()\n    }\n}\n"
  },
  {
    "path": "src/sql/exceptions.rs",
    "content": "use std::fmt::Debug;\n\nuse pyo3::{create_exception, PyErr};\n\n// Identifies exceptions that occur while attempting to generate a `LogicalPlan` from a SQL string\ncreate_exception!(rust, ParsingException, pyo3::exceptions::PyException);\n\n// Identifies exceptions that occur during attempts to optimization an existing `LogicalPlan`\ncreate_exception!(rust, OptimizationException, pyo3::exceptions::PyException);\n\npub fn py_type_err(e: impl Debug) -> PyErr {\n    PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(\"{e:?}\"))\n}\n\npub fn py_runtime_err(e: impl Debug) -> PyErr {\n    PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(\"{e:?}\"))\n}\n\npub fn py_parsing_exp(e: impl Debug) -> PyErr {\n    PyErr::new::<ParsingException, _>(format!(\"{e:?}\"))\n}\n\npub fn py_optimization_exp(e: impl Debug) -> PyErr {\n    PyErr::new::<OptimizationException, _>(format!(\"{e:?}\"))\n}\n"
  },
  {
    "path": "src/sql/function.rs",
    "content": "use std::collections::HashMap;\n\nuse datafusion_python::datafusion::arrow::datatypes::DataType;\nuse pyo3::prelude::*;\n\nuse super::types::PyDataType;\n\n#[pyclass(name = \"DaskFunction\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone)]\npub struct DaskFunction {\n    #[pyo3(get, set)]\n    pub(crate) name: String,\n    pub(crate) return_types: HashMap<Vec<DataType>, DataType>,\n    pub(crate) aggregation: bool,\n}\n\nimpl DaskFunction {\n    pub fn new(\n        function_name: String,\n        input_types: Vec<PyDataType>,\n        return_type: PyDataType,\n        aggregation_bool: bool,\n    ) -> Self {\n        let mut func = Self {\n            name: function_name,\n            return_types: HashMap::new(),\n            aggregation: aggregation_bool,\n        };\n        func.add_type_mapping(input_types, return_type);\n        func\n    }\n\n    pub fn add_type_mapping(&mut self, input_types: Vec<PyDataType>, return_type: PyDataType) {\n        self.return_types.insert(\n            input_types.iter().map(|t| t.clone().into()).collect(),\n            return_type.into(),\n        );\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/aggregate.rs",
    "content": "use datafusion_python::datafusion_expr::{\n    expr::{AggregateFunction, AggregateUDF, Alias},\n    logical_plan::{Aggregate, Distinct},\n    Expr,\n    LogicalPlan,\n};\nuse pyo3::prelude::*;\n\nuse crate::{\n    expression::{py_expr_list, PyExpr},\n    sql::exceptions::py_type_err,\n};\n\n#[pyclass(name = \"Aggregate\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PyAggregate {\n    aggregate: Option<Aggregate>,\n    distinct: Option<Distinct>,\n}\n\n#[pymethods]\nimpl PyAggregate {\n    /// Determine the PyExprs that should be \"Distinct-ed\"\n    #[pyo3(name = \"getDistinctColumns\")]\n    pub fn distinct_columns(&self) -> PyResult<Vec<String>> {\n        match &self.distinct {\n            Some(e) => Ok(e.input.schema().field_names()),\n            None => Err(py_type_err(\n                \"distinct_columns invoked for non distinct instance\",\n            )),\n        }\n    }\n\n    /// Returns a Vec of the group expressions\n    #[pyo3(name = \"getGroupSets\")]\n    pub fn group_expressions(&self) -> PyResult<Vec<PyExpr>> {\n        match &self.aggregate {\n            Some(e) => py_expr_list(&e.input, &e.group_expr),\n            None => Ok(vec![]),\n        }\n    }\n\n    /// Returns the inner Aggregate Expr(s)\n    #[pyo3(name = \"getNamedAggCalls\")]\n    pub fn agg_expressions(&self) -> PyResult<Vec<PyExpr>> {\n        match &self.aggregate {\n            Some(e) => py_expr_list(&e.input, &e.aggr_expr),\n            None => Ok(vec![]),\n        }\n    }\n\n    #[pyo3(name = \"getAggregationFuncName\")]\n    pub fn agg_func_name(&self, expr: PyExpr) -> PyResult<String> {\n        _agg_func_name(&expr.expr)\n    }\n\n    #[pyo3(name = \"getArgs\")]\n    pub fn aggregation_arguments(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {\n        self._aggregation_arguments(&expr.expr)\n    }\n\n    #[pyo3(name = \"isAggExprDistinct\")]\n    pub fn distinct_agg_expr(&self, expr: PyExpr) -> PyResult<bool> {\n        
_distinct_agg_expr(&expr.expr)\n    }\n\n    #[pyo3(name = \"isDistinctNode\")]\n    pub fn distinct_node(&self) -> PyResult<bool> {\n        Ok(self.distinct.is_some())\n    }\n}\n\nimpl PyAggregate {\n    fn _aggregation_arguments(&self, expr: &Expr) -> PyResult<Vec<PyExpr>> {\n        match expr {\n            Expr::Alias(Alias { expr, .. }) => self._aggregation_arguments(expr.as_ref()),\n            Expr::AggregateFunction(AggregateFunction { fun: _, args, .. })\n            | Expr::AggregateUDF(AggregateUDF { fun: _, args, .. }) => match &self.aggregate {\n                Some(e) => py_expr_list(&e.input, args),\n                None => Ok(vec![]),\n            },\n            _ => Err(py_type_err(\n                \"Encountered a non Aggregate type in aggregation_arguments\",\n            )),\n        }\n    }\n}\n\nfn _agg_func_name(expr: &Expr) -> PyResult<String> {\n    match expr {\n        Expr::Alias(Alias { expr, .. }) => _agg_func_name(expr.as_ref()),\n        Expr::AggregateFunction(AggregateFunction { fun, .. }) => Ok(fun.to_string()),\n        Expr::AggregateUDF(AggregateUDF { fun, .. }) => Ok(fun.name.clone()),\n        _ => Err(py_type_err(\n            \"Encountered a non Aggregate type in agg_func_name\",\n        )),\n    }\n}\n\nfn _distinct_agg_expr(expr: &Expr) -> PyResult<bool> {\n    match expr {\n        Expr::Alias(Alias { expr, .. }) => _distinct_agg_expr(expr.as_ref()),\n        Expr::AggregateFunction(AggregateFunction { distinct, .. }) => Ok(*distinct),\n        Expr::AggregateUDF { .. 
} => {\n            // DataFusion does not support DISTINCT in UDAFs\n            Ok(false)\n        }\n        _ => Err(py_type_err(\n            \"Encountered a non Aggregate type in distinct_agg_expr\",\n        )),\n    }\n}\n\nimpl TryFrom<LogicalPlan> for PyAggregate {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Aggregate(aggregate) => Ok(PyAggregate {\n                aggregate: Some(aggregate),\n                distinct: None,\n            }),\n            LogicalPlan::Distinct(distinct) => Ok(PyAggregate {\n                aggregate: None,\n                distinct: Some(distinct),\n            }),\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/alter_schema.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{\n        logical_plan::{Extension, UserDefinedLogicalNode},\n        Expr,\n        LogicalPlan,\n    },\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::sql::{exceptions::py_type_err, logical};\n\n#[derive(Clone, PartialEq)]\npub struct AlterSchemaPlanNode {\n    pub schema: DFSchemaRef,\n    pub old_schema_name: String,\n    pub new_schema_name: String,\n}\n\nimpl Debug for AlterSchemaPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for AlterSchemaPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema.hash(state);\n        self.old_schema_name.hash(state);\n        self.new_schema_name.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for AlterSchemaPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // ALTER SCHEMA {table_name}\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(\n            f,\n            \"Alter Schema: old_schema_name: {:?}, new_schema_name: {:?}\",\n            self.old_schema_name, self.new_schema_name\n        )\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        _inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        Arc::new(AlterSchemaPlanNode {\n            schema: Arc::new(DFSchema::empty()),\n            old_schema_name: 
self.old_schema_name.clone(),\n            new_schema_name: self.new_schema_name.clone(),\n        })\n    }\n\n    fn name(&self) -> &str {\n        \"AlterSchema\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"AlterSchema\", module = \"dask_sql\", subclass)]\npub struct PyAlterSchema {\n    pub(crate) alter_schema: AlterSchemaPlanNode,\n}\n\n#[pymethods]\nimpl PyAlterSchema {\n    #[pyo3(name = \"getOldSchemaName\")]\n    fn get_old_schema_name(&self) -> PyResult<String> {\n        Ok(self.alter_schema.old_schema_name.clone())\n    }\n\n    #[pyo3(name = \"getNewSchemaName\")]\n    fn get_new_schema_name(&self) -> PyResult<String> {\n        Ok(self.alter_schema.new_schema_name.clone())\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyAlterSchema {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Extension(Extension { node })\n                if node\n                    .as_any()\n                    .downcast_ref::<AlterSchemaPlanNode>()\n                    .is_some() =>\n            {\n                let ext = node\n                    .as_any()\n                    .downcast_ref::<AlterSchemaPlanNode>()\n                    .expect(\"AlterSchemaPlanNode\");\n                Ok(PyAlterSchema {\n                    alter_schema: ext.clone(),\n                })\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/alter_table.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{\n        logical_plan::{Extension, UserDefinedLogicalNode},\n        Expr,\n        LogicalPlan,\n    },\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::sql::{exceptions::py_type_err, logical};\n\n#[derive(Clone, PartialEq)]\npub struct AlterTablePlanNode {\n    pub schema: DFSchemaRef,\n    pub old_table_name: String,\n    pub new_table_name: String,\n    pub schema_name: Option<String>,\n    pub if_exists: bool,\n}\n\nimpl Debug for AlterTablePlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for AlterTablePlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema.hash(state);\n        self.old_table_name.hash(state);\n        self.new_table_name.hash(state);\n        self.schema_name.hash(state);\n        self.if_exists.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for AlterTablePlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // ALTER TABLE {table_name}\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(\n            f,\n            \"Alter Table: old_table_name: {:?}, new_table_name: {:?}, schema_name: {:?}\",\n            self.old_table_name, self.new_table_name, self.schema_name\n        )\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        _inputs: &[LogicalPlan],\n    ) -> Arc<dyn 
UserDefinedLogicalNode> {\n        Arc::new(AlterTablePlanNode {\n            schema: Arc::new(DFSchema::empty()),\n            old_table_name: self.old_table_name.clone(),\n            new_table_name: self.new_table_name.clone(),\n            schema_name: self.schema_name.clone(),\n            if_exists: self.if_exists,\n        })\n    }\n\n    fn name(&self) -> &str {\n        \"AlterTable\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"AlterTable\", module = \"dask_sql\", subclass)]\npub struct PyAlterTable {\n    pub(crate) alter_table: AlterTablePlanNode,\n}\n\n#[pymethods]\nimpl PyAlterTable {\n    #[pyo3(name = \"getOldTableName\")]\n    fn get_old_table_name(&self) -> PyResult<String> {\n        Ok(self.alter_table.old_table_name.clone())\n    }\n\n    #[pyo3(name = \"getNewTableName\")]\n    fn get_new_table_name(&self) -> PyResult<String> {\n        Ok(self.alter_table.new_table_name.clone())\n    }\n\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<Option<String>> {\n        Ok(self.alter_table.schema_name.clone())\n    }\n\n    #[pyo3(name = \"getIfExists\")]\n    fn get_if_exists(&self) -> PyResult<bool> {\n        Ok(self.alter_table.if_exists)\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyAlterTable {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Extension(Extension { node })\n                if node.as_any().downcast_ref::<AlterTablePlanNode>().is_some() =>\n            {\n                let ext = node\n                    .as_any()\n                    
.downcast_ref::<AlterTablePlanNode>()\n                    .expect(\"AlterTablePlanNode\");\n                Ok(PyAlterTable {\n                    alter_table: ext.clone(),\n                })\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/analyze_table.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{\n        logical_plan::{Extension, UserDefinedLogicalNode},\n        Expr,\n        LogicalPlan,\n    },\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::sql::{exceptions::py_type_err, logical};\n\n#[derive(Clone, PartialEq)]\npub struct AnalyzeTablePlanNode {\n    pub schema: DFSchemaRef,\n    pub table_name: String,\n    pub schema_name: Option<String>,\n    pub columns: Vec<String>,\n}\n\nimpl Debug for AnalyzeTablePlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for AnalyzeTablePlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema.hash(state);\n        self.table_name.hash(state);\n        self.schema_name.hash(state);\n        self.columns.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for AnalyzeTablePlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // ANALYZE TABLE {table_name}\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(\n            f,\n            \"Analyze Table: table_name: {:?}, columns: {:?}\",\n            self.table_name, self.columns\n        )\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        _inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        Arc::new(AnalyzeTablePlanNode {\n            schema: Arc::new(DFSchema::empty()),\n            
table_name: self.table_name.clone(),\n            schema_name: self.schema_name.clone(),\n            columns: self.columns.clone(),\n        })\n    }\n\n    fn name(&self) -> &str {\n        \"AnalyzeTable\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"AnalyzeTable\", module = \"dask_sql\", subclass)]\npub struct PyAnalyzeTable {\n    pub(crate) analyze_table: AnalyzeTablePlanNode,\n}\n\n#[pymethods]\nimpl PyAnalyzeTable {\n    #[pyo3(name = \"getTableName\")]\n    fn get_table_name(&self) -> PyResult<String> {\n        Ok(self.analyze_table.table_name.clone())\n    }\n\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<Option<String>> {\n        Ok(self.analyze_table.schema_name.clone())\n    }\n\n    #[pyo3(name = \"getColumns\")]\n    fn get_columns(&self) -> PyResult<Vec<String>> {\n        Ok(self.analyze_table.columns.clone())\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyAnalyzeTable {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Extension(Extension { node })\n                if node\n                    .as_any()\n                    .downcast_ref::<AnalyzeTablePlanNode>()\n                    .is_some() =>\n            {\n                let ext = node\n                    .as_any()\n                    .downcast_ref::<AnalyzeTablePlanNode>()\n                    .expect(\"AnalyzeTablePlanNode\");\n                Ok(PyAnalyzeTable {\n                    analyze_table: ext.clone(),\n                })\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/create_catalog_schema.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan},\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::sql::{exceptions::py_type_err, logical};\n\n#[derive(Clone, PartialEq)]\npub struct CreateCatalogSchemaPlanNode {\n    pub schema: DFSchemaRef,\n    pub schema_name: String,\n    pub if_not_exists: bool,\n    pub or_replace: bool,\n}\n\nimpl Debug for CreateCatalogSchemaPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for CreateCatalogSchemaPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema.hash(state);\n        self.schema_name.hash(state);\n        self.if_not_exists.hash(state);\n        self.or_replace.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for CreateCatalogSchemaPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // CREATE SCHEMA\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(\n            f,\n            \"CreateCatalogSchema: schema_name={}, or_replace={}, if_not_exists={}\",\n            self.schema_name, self.or_replace, self.if_not_exists\n        )\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        _inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        Arc::new(CreateCatalogSchemaPlanNode {\n            schema: Arc::new(DFSchema::empty()),\n    
        schema_name: self.schema_name.clone(),\n            if_not_exists: self.if_not_exists,\n            or_replace: self.or_replace,\n        })\n    }\n\n    fn name(&self) -> &str {\n        \"CreateCatalogSchema\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"CreateCatalogSchema\", module = \"dask_sql\", subclass)]\npub struct PyCreateCatalogSchema {\n    pub(crate) create_catalog_schema: CreateCatalogSchemaPlanNode,\n}\n\n#[pymethods]\nimpl PyCreateCatalogSchema {\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<String> {\n        Ok(self.create_catalog_schema.schema_name.clone())\n    }\n\n    #[pyo3(name = \"getIfNotExists\")]\n    fn get_if_not_exists(&self) -> PyResult<bool> {\n        Ok(self.create_catalog_schema.if_not_exists)\n    }\n\n    #[pyo3(name = \"getReplace\")]\n    fn get_replace(&self) -> PyResult<bool> {\n        Ok(self.create_catalog_schema.or_replace)\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyCreateCatalogSchema {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            logical::LogicalPlan::Extension(extension) => {\n                if let Some(ext) = extension\n                    .node\n                    .as_any()\n                    .downcast_ref::<CreateCatalogSchemaPlanNode>()\n                {\n                    Ok(PyCreateCatalogSchema {\n                        create_catalog_schema: ext.clone(),\n                    })\n                } else {\n                    Err(py_type_err(\"unexpected plan\"))\n                }\n            }\n            _ => 
Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/create_experiment.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::DFSchemaRef,\n    datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan},\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::{\n    parser::PySqlArg,\n    sql::{exceptions::py_type_err, logical},\n};\n\n#[derive(Clone, PartialEq)]\npub struct CreateExperimentPlanNode {\n    pub schema_name: Option<String>,\n    pub experiment_name: String,\n    pub input: LogicalPlan,\n    pub if_not_exists: bool,\n    pub or_replace: bool,\n    pub with_options: Vec<(String, PySqlArg)>,\n}\n\nimpl Debug for CreateExperimentPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for CreateExperimentPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema_name.hash(state);\n        self.experiment_name.hash(state);\n        self.input.hash(state);\n        self.if_not_exists.hash(state);\n        self.or_replace.hash(state);\n        // self.with_options.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for CreateExperimentPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![&self.input]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        self.input.schema()\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // CREATE EXPERIMENT\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(\n            f,\n            \"CreateExperiment: experiment_name={}\",\n            self.experiment_name\n        )\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        inputs: &[LogicalPlan],\n    ) 
-> Arc<dyn UserDefinedLogicalNode> {\n        assert_eq!(inputs.len(), 1, \"input size inconsistent\");\n        Arc::new(CreateExperimentPlanNode {\n            schema_name: self.schema_name.clone(),\n            experiment_name: self.experiment_name.clone(),\n            input: inputs[0].clone(),\n            if_not_exists: self.if_not_exists,\n            or_replace: self.or_replace,\n            with_options: self.with_options.clone(),\n        })\n    }\n\n    fn name(&self) -> &str {\n        \"CreateExperiment\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"CreateExperiment\", module = \"dask_sql\", subclass)]\npub struct PyCreateExperiment {\n    pub(crate) create_experiment: CreateExperimentPlanNode,\n}\n\n#[pymethods]\nimpl PyCreateExperiment {\n    /// Creating an experiment requires that a subquery be passed to the CREATE EXPERIMENT\n    /// statement to be used to gather the dataset which should be used for the\n    /// experiment. 
This function returns that portion of the statement.\n    #[pyo3(name = \"getSelectQuery\")]\n    fn get_select_query(&self) -> PyResult<logical::PyLogicalPlan> {\n        Ok(self.create_experiment.input.clone().into())\n    }\n\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<Option<String>> {\n        Ok(self.create_experiment.schema_name.clone())\n    }\n\n    #[pyo3(name = \"getExperimentName\")]\n    fn get_experiment_name(&self) -> PyResult<String> {\n        Ok(self.create_experiment.experiment_name.clone())\n    }\n\n    #[pyo3(name = \"getIfNotExists\")]\n    fn get_if_not_exists(&self) -> PyResult<bool> {\n        Ok(self.create_experiment.if_not_exists)\n    }\n\n    #[pyo3(name = \"getOrReplace\")]\n    pub fn get_or_replace(&self) -> PyResult<bool> {\n        Ok(self.create_experiment.or_replace)\n    }\n\n    #[pyo3(name = \"getSQLWithOptions\")]\n    fn sql_with_options(&self) -> PyResult<Vec<(String, PySqlArg)>> {\n        Ok(self.create_experiment.with_options.clone())\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyCreateExperiment {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            logical::LogicalPlan::Extension(extension) => {\n                if let Some(ext) = extension\n                    .node\n                    .as_any()\n                    .downcast_ref::<CreateExperimentPlanNode>()\n                {\n                    Ok(PyCreateExperiment {\n                        create_experiment: ext.clone(),\n                    })\n                } else {\n                    Err(py_type_err(\"unexpected plan\"))\n                }\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/create_memory_table.rs",
    "content": "use datafusion_python::datafusion_expr::{\n    logical_plan::{CreateMemoryTable, CreateView},\n    DdlStatement,\n    LogicalPlan,\n};\nuse pyo3::prelude::*;\n\nuse crate::sql::{exceptions::py_type_err, logical::PyLogicalPlan};\n\n#[pyclass(name = \"CreateMemoryTable\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PyCreateMemoryTable {\n    create_memory_table: Option<CreateMemoryTable>,\n    create_view: Option<CreateView>,\n}\n\n#[pymethods]\nimpl PyCreateMemoryTable {\n    #[pyo3(name = \"getQualifiedName\")]\n    pub fn get_table_name(&self) -> PyResult<String> {\n        Ok(match &self.create_memory_table {\n            Some(create_memory_table) => create_memory_table.name.to_string(),\n            None => match &self.create_view {\n                Some(create_view) => create_view.name.to_string(),\n                None => {\n                    return Err(py_type_err(\n                        \"Encountered a non CreateMemoryTable/CreateView type in get_input\",\n                    ))\n                }\n            },\n        })\n    }\n\n    #[pyo3(name = \"getInput\")]\n    pub fn get_input(&self) -> PyResult<PyLogicalPlan> {\n        Ok(match &self.create_memory_table {\n            Some(create_memory_table) => PyLogicalPlan {\n                original_plan: (*create_memory_table.input).clone(),\n                current_node: None,\n            },\n            None => match &self.create_view {\n                Some(create_view) => PyLogicalPlan {\n                    original_plan: (*create_view.input).clone(),\n                    current_node: None,\n                },\n                None => {\n                    return Err(py_type_err(\n                        \"Encountered a non CreateMemoryTable/CreateView type in get_input\",\n                    ))\n                }\n            },\n        })\n    }\n\n    #[pyo3(name = \"getIfNotExists\")]\n    pub fn get_if_not_exists(&self) -> PyResult<bool> {\n        
Ok(match &self.create_memory_table {\n            Some(create_memory_table) => create_memory_table.if_not_exists,\n            None => false, // TODO: in the future we may want to set this based on dialect\n        })\n    }\n\n    #[pyo3(name = \"getOrReplace\")]\n    pub fn get_or_replace(&self) -> PyResult<bool> {\n        Ok(match &self.create_memory_table {\n            Some(create_memory_table) => create_memory_table.or_replace,\n            None => match &self.create_view {\n                Some(create_view) => create_view.or_replace,\n                None => {\n                    return Err(py_type_err(\n                        \"Encountered a non CreateMemoryTable/CreateView type in get_or_replace\",\n                    ))\n                }\n            },\n        })\n    }\n\n    #[pyo3(name = \"isTable\")]\n    pub fn is_table(&self) -> PyResult<bool> {\n        Ok(self.create_memory_table.is_some())\n    }\n}\n\nimpl TryFrom<LogicalPlan> for PyCreateMemoryTable {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        Ok(match logical_plan {\n            LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(cmt)) => PyCreateMemoryTable {\n                create_memory_table: Some(cmt),\n                create_view: None,\n            },\n            LogicalPlan::Ddl(DdlStatement::CreateView(cv)) => PyCreateMemoryTable {\n                create_memory_table: None,\n                create_view: Some(cv),\n            },\n            _ => return Err(py_type_err(\"unexpected plan\")),\n        })\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/create_model.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::DFSchemaRef,\n    datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan},\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::{\n    parser::PySqlArg,\n    sql::{exceptions::py_type_err, logical},\n};\n\n#[derive(Clone, PartialEq)]\npub struct CreateModelPlanNode {\n    pub schema_name: Option<String>,\n    pub model_name: String,\n    pub input: LogicalPlan,\n    pub if_not_exists: bool,\n    pub or_replace: bool,\n    pub with_options: Vec<(String, PySqlArg)>,\n}\n\nimpl Debug for CreateModelPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for CreateModelPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema_name.hash(state);\n        self.model_name.hash(state);\n        self.input.hash(state);\n        self.if_not_exists.hash(state);\n        self.or_replace.hash(state);\n        // self.with_options.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for CreateModelPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![&self.input]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        self.input.schema()\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // CREATE MODEL\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(f, \"CreateModel: model_name={}\", self.model_name)\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        assert_eq!(inputs.len(), 1, \"input size 
inconsistent\");\n        Arc::new(CreateModelPlanNode {\n            schema_name: self.schema_name.clone(),\n            model_name: self.model_name.clone(),\n            input: inputs[0].clone(),\n            if_not_exists: self.if_not_exists,\n            or_replace: self.or_replace,\n            with_options: self.with_options.clone(),\n        })\n    }\n\n    fn name(&self) -> &str {\n        \"CreateModel\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"CreateModel\", module = \"dask_sql\", subclass)]\npub struct PyCreateModel {\n    pub(crate) create_model: CreateModelPlanNode,\n}\n\n#[pymethods]\nimpl PyCreateModel {\n    /// Creating a model requires that a subquery be passed to the CREATE MODEL\n    /// statement to be used to gather the dataset which should be used for the\n    /// model. 
This function returns that portion of the statement.\n    #[pyo3(name = \"getSelectQuery\")]\n    fn get_select_query(&self) -> PyResult<logical::PyLogicalPlan> {\n        Ok(self.create_model.input.clone().into())\n    }\n\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<Option<String>> {\n        Ok(self.create_model.schema_name.clone())\n    }\n\n    #[pyo3(name = \"getModelName\")]\n    fn get_model_name(&self) -> PyResult<String> {\n        Ok(self.create_model.model_name.clone())\n    }\n\n    #[pyo3(name = \"getIfNotExists\")]\n    fn get_if_not_exists(&self) -> PyResult<bool> {\n        Ok(self.create_model.if_not_exists)\n    }\n\n    #[pyo3(name = \"getOrReplace\")]\n    pub fn get_or_replace(&self) -> PyResult<bool> {\n        Ok(self.create_model.or_replace)\n    }\n\n    #[pyo3(name = \"getSQLWithOptions\")]\n    fn sql_with_options(&self) -> PyResult<Vec<(String, PySqlArg)>> {\n        Ok(self.create_model.with_options.clone())\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyCreateModel {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            logical::LogicalPlan::Extension(extension) => {\n                if let Some(ext) = extension\n                    .node\n                    .as_any()\n                    .downcast_ref::<CreateModelPlanNode>()\n                {\n                    Ok(PyCreateModel {\n                        create_model: ext.clone(),\n                    })\n                } else {\n                    Err(py_type_err(\"unexpected plan\"))\n                }\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/create_table.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan},\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::{\n    parser::PySqlArg,\n    sql::{exceptions::py_type_err, logical},\n};\n\n#[derive(Clone, PartialEq)]\npub struct CreateTablePlanNode {\n    pub schema: DFSchemaRef,\n    pub schema_name: Option<String>, // \"something\" in `something.table_name`\n    pub table_name: String,\n    pub if_not_exists: bool,\n    pub or_replace: bool,\n    pub with_options: Vec<(String, PySqlArg)>,\n}\n\nimpl Debug for CreateTablePlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for CreateTablePlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema.hash(state);\n        self.schema_name.hash(state);\n        self.table_name.hash(state);\n        self.if_not_exists.hash(state);\n        self.or_replace.hash(state);\n        // self.with_options.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for CreateTablePlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // CREATE TABLE\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(f, \"CreateTable: table_name={}\", self.table_name)\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        _inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        
Arc::new(CreateTablePlanNode {\n            schema: Arc::new(DFSchema::empty()),\n            schema_name: self.schema_name.clone(),\n            table_name: self.table_name.clone(),\n            if_not_exists: self.if_not_exists,\n            or_replace: self.or_replace,\n            with_options: self.with_options.clone(),\n        })\n    }\n\n    fn name(&self) -> &str {\n        \"CreateTable\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"CreateTable\", module = \"dask_sql\", subclass)]\npub struct PyCreateTable {\n    pub(crate) create_table: CreateTablePlanNode,\n}\n\n#[pymethods]\nimpl PyCreateTable {\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<Option<String>> {\n        Ok(self.create_table.schema_name.clone())\n    }\n\n    #[pyo3(name = \"getTableName\")]\n    fn get_table_name(&self) -> PyResult<String> {\n        Ok(self.create_table.table_name.clone())\n    }\n\n    #[pyo3(name = \"getIfNotExists\")]\n    fn get_if_not_exists(&self) -> PyResult<bool> {\n        Ok(self.create_table.if_not_exists)\n    }\n\n    #[pyo3(name = \"getOrReplace\")]\n    fn get_or_replace(&self) -> PyResult<bool> {\n        Ok(self.create_table.or_replace)\n    }\n\n    #[pyo3(name = \"getSQLWithOptions\")]\n    fn sql_with_options(&self) -> PyResult<Vec<(String, PySqlArg)>> {\n        Ok(self.create_table.with_options.clone())\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyCreateTable {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            logical::LogicalPlan::Extension(extension) => {\n                if let Some(ext) = 
extension\n                    .node\n                    .as_any()\n                    .downcast_ref::<CreateTablePlanNode>()\n                {\n                    Ok(PyCreateTable {\n                        create_table: ext.clone(),\n                    })\n                } else {\n                    Err(py_type_err(\"unexpected plan\"))\n                }\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/describe_model.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan},\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::sql::{exceptions::py_type_err, logical};\n\n#[derive(Clone, PartialEq)]\npub struct DescribeModelPlanNode {\n    pub schema: DFSchemaRef,\n    pub schema_name: Option<String>,\n    pub model_name: String,\n}\n\nimpl Debug for DescribeModelPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for DescribeModelPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema.hash(state);\n        self.schema_name.hash(state);\n        self.model_name.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for DescribeModelPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // DESCRIBE MODEL\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(f, \"DescribeModel: model_name={}\", self.model_name)\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        assert_eq!(inputs.len(), 0, \"input size inconsistent\");\n        Arc::new(DescribeModelPlanNode {\n            schema: Arc::new(DFSchema::empty()),\n            schema_name: self.schema_name.clone(),\n            model_name: self.model_name.clone(),\n        })\n    }\n\n    fn name(&self) -> &str {\n   
     \"DescribeModel\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"DescribeModel\", module = \"dask_sql\", subclass)]\npub struct PyDescribeModel {\n    pub(crate) describe_model: DescribeModelPlanNode,\n}\n\n#[pymethods]\nimpl PyDescribeModel {\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<Option<String>> {\n        Ok(self.describe_model.schema_name.clone())\n    }\n\n    #[pyo3(name = \"getModelName\")]\n    fn get_model_name(&self) -> PyResult<String> {\n        Ok(self.describe_model.model_name.clone())\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyDescribeModel {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            logical::LogicalPlan::Extension(extension) => {\n                if let Some(ext) = extension\n                    .node\n                    .as_any()\n                    .downcast_ref::<DescribeModelPlanNode>()\n                {\n                    Ok(PyDescribeModel {\n                        describe_model: ext.clone(),\n                    })\n                } else {\n                    Err(py_type_err(\"unexpected plan\"))\n                }\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/drop_model.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan},\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::sql::{exceptions::py_type_err, logical};\n\n#[derive(Clone, PartialEq)]\npub struct DropModelPlanNode {\n    pub schema_name: Option<String>,\n    pub model_name: String,\n    pub if_exists: bool,\n    pub schema: DFSchemaRef,\n}\n\nimpl Debug for DropModelPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for DropModelPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema_name.hash(state);\n        self.model_name.hash(state);\n        self.if_exists.hash(state);\n        self.schema.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for DropModelPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // DROP MODEL\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(f, \"DropModel: model_name={}\", self.model_name)\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        assert_eq!(inputs.len(), 0, \"input size inconsistent\");\n        Arc::new(DropModelPlanNode {\n            schema_name: self.schema_name.clone(),\n            model_name: self.model_name.clone(),\n            if_exists: self.if_exists,\n            schema: 
Arc::new(DFSchema::empty()),\n        })\n    }\n\n    fn name(&self) -> &str {\n        \"DropModel\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"DropModel\", module = \"dask_sql\", subclass)]\npub struct PyDropModel {\n    pub(crate) drop_model: DropModelPlanNode,\n}\n\n#[pymethods]\nimpl PyDropModel {\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<Option<String>> {\n        Ok(self.drop_model.schema_name.clone())\n    }\n\n    #[pyo3(name = \"getModelName\")]\n    fn get_model_name(&self) -> PyResult<String> {\n        Ok(self.drop_model.model_name.clone())\n    }\n\n    #[pyo3(name = \"getIfExists\")]\n    pub fn get_if_exists(&self) -> PyResult<bool> {\n        Ok(self.drop_model.if_exists)\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyDropModel {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            logical::LogicalPlan::Extension(extension) => {\n                if let Some(ext) = extension.node.as_any().downcast_ref::<DropModelPlanNode>() {\n                    Ok(PyDropModel {\n                        drop_model: ext.clone(),\n                    })\n                } else {\n                    Err(py_type_err(\"unexpected plan\"))\n                }\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/drop_schema.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan},\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::sql::{exceptions::py_type_err, logical};\n\n#[derive(Clone, PartialEq)]\npub struct DropSchemaPlanNode {\n    pub schema: DFSchemaRef,\n    pub schema_name: String,\n    pub if_exists: bool,\n}\n\nimpl Debug for DropSchemaPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for DropSchemaPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema.hash(state);\n        self.schema_name.hash(state);\n        self.if_exists.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for DropSchemaPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // DROP SCHEMA\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(f, \"DropSchema: schema_name={}\", self.schema_name)\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        _inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        Arc::new(DropSchemaPlanNode {\n            schema: Arc::new(DFSchema::empty()),\n            schema_name: self.schema_name.clone(),\n            if_exists: self.if_exists,\n        })\n    }\n\n    fn name(&self) -> &str {\n        \"DropSchema\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = 
state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"DropSchema\", module = \"dask_sql\", subclass)]\npub struct PyDropSchema {\n    pub(crate) drop_schema: DropSchemaPlanNode,\n}\n\n#[pymethods]\nimpl PyDropSchema {\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<String> {\n        Ok(self.drop_schema.schema_name.clone())\n    }\n\n    #[pyo3(name = \"getIfExists\")]\n    fn get_if_exists(&self) -> PyResult<bool> {\n        Ok(self.drop_schema.if_exists)\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyDropSchema {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            logical::LogicalPlan::Extension(extension) => {\n                if let Some(ext) = extension.node.as_any().downcast_ref::<DropSchemaPlanNode>() {\n                    Ok(PyDropSchema {\n                        drop_schema: ext.clone(),\n                    })\n                } else {\n                    Err(py_type_err(\"unexpected plan\"))\n                }\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/drop_table.rs",
    "content": "use datafusion_python::datafusion_expr::{\n    logical_plan::{DropTable, LogicalPlan},\n    DdlStatement,\n};\nuse pyo3::prelude::*;\n\nuse crate::sql::exceptions::py_type_err;\n\n#[pyclass(name = \"DropTable\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PyDropTable {\n    drop_table: DropTable,\n}\n\n#[pymethods]\nimpl PyDropTable {\n    #[pyo3(name = \"getQualifiedName\")]\n    pub fn get_name(&self) -> PyResult<String> {\n        Ok(self.drop_table.name.to_string())\n    }\n\n    #[pyo3(name = \"getIfExists\")]\n    pub fn get_if_exists(&self) -> PyResult<bool> {\n        Ok(self.drop_table.if_exists)\n    }\n}\n\nimpl TryFrom<LogicalPlan> for PyDropTable {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Ddl(DdlStatement::DropTable(drop_table)) => Ok(PyDropTable { drop_table }),\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/empty_relation.rs",
    "content": "use datafusion_python::datafusion_expr::logical_plan::{EmptyRelation, LogicalPlan};\nuse pyo3::prelude::*;\n\nuse crate::sql::exceptions::py_type_err;\n\n#[pyclass(name = \"EmptyRelation\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PyEmptyRelation {\n    empty_relation: EmptyRelation,\n}\n\nimpl TryFrom<LogicalPlan> for PyEmptyRelation {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::EmptyRelation(empty_relation) => Ok(PyEmptyRelation { empty_relation }),\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n\n#[pymethods]\nimpl PyEmptyRelation {\n    /// Even though a relation results in an \"empty\" table column names\n    /// will still be projected and must be captured in order to present\n    /// the expected output to the user. This logic captures the names\n    /// of those columns and returns them to the Python logic where\n    /// there are rendered to the user\n    #[pyo3(name = \"emptyColumnNames\")]\n    pub fn empty_column_names(&self) -> PyResult<Vec<String>> {\n        Ok(self.empty_relation.schema.field_names())\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/explain.rs",
    "content": "use datafusion_python::datafusion_expr::{logical_plan::Explain, LogicalPlan};\nuse pyo3::prelude::*;\n\nuse crate::sql::exceptions::py_type_err;\n\n#[pyclass(name = \"Explain\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PyExplain {\n    explain: Explain,\n}\n\n#[pymethods]\nimpl PyExplain {\n    /// Returns explain strings\n    #[pyo3(name = \"getExplainString\")]\n    pub fn get_explain_string(&self) -> PyResult<Vec<String>> {\n        let mut string_plans: Vec<String> = Vec::new();\n        for stringified_plan in &self.explain.stringified_plans {\n            string_plans.push((*stringified_plan.plan).clone());\n        }\n        Ok(string_plans)\n    }\n}\n\nimpl TryFrom<LogicalPlan> for PyExplain {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Explain(explain) => Ok(PyExplain { explain }),\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/export_model.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan},\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::{\n    parser::PySqlArg,\n    sql::{exceptions::py_type_err, logical},\n};\n\n#[derive(Clone, PartialEq)]\npub struct ExportModelPlanNode {\n    pub schema: DFSchemaRef,\n    pub schema_name: Option<String>,\n    pub model_name: String,\n    pub with_options: Vec<(String, PySqlArg)>,\n}\n\nimpl Debug for ExportModelPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for ExportModelPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema.hash(state);\n        self.schema_name.hash(state);\n        self.model_name.hash(state);\n        // self.with_options.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for ExportModelPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // EXPORT MODEL\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(f, \"ExportModel: model_name={}\", self.model_name)\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        assert_eq!(inputs.len(), 0, \"input size inconsistent\");\n        Arc::new(ExportModelPlanNode {\n            schema: Arc::new(DFSchema::empty()),\n            schema_name: 
self.schema_name.clone(),\n            model_name: self.model_name.clone(),\n            with_options: self.with_options.clone(),\n        })\n    }\n\n    fn name(&self) -> &str {\n        \"ExportModel\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"ExportModel\", module = \"dask_sql\", subclass)]\npub struct PyExportModel {\n    pub(crate) export_model: ExportModelPlanNode,\n}\n\n#[pymethods]\nimpl PyExportModel {\n    #[pyo3(name = \"getModelName\")]\n    fn get_model_name(&self) -> PyResult<String> {\n        Ok(self.export_model.model_name.clone())\n    }\n\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<Option<String>> {\n        Ok(self.export_model.schema_name.clone())\n    }\n\n    #[pyo3(name = \"getSQLWithOptions\")]\n    fn sql_with_options(&self) -> PyResult<Vec<(String, PySqlArg)>> {\n        Ok(self.export_model.with_options.clone())\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyExportModel {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            logical::LogicalPlan::Extension(extension) => {\n                if let Some(ext) = extension\n                    .node\n                    .as_any()\n                    .downcast_ref::<ExportModelPlanNode>()\n                {\n                    Ok(PyExportModel {\n                        export_model: ext.clone(),\n                    })\n                } else {\n                    Err(py_type_err(\"unexpected plan\"))\n                }\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/filter.rs",
    "content": "use datafusion_python::datafusion_expr::{logical_plan::Filter, LogicalPlan};\nuse pyo3::prelude::*;\n\nuse crate::{expression::PyExpr, sql::exceptions::py_type_err};\n\n#[pyclass(name = \"Filter\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PyFilter {\n    filter: Filter,\n}\n\n#[pymethods]\nimpl PyFilter {\n    /// LogicalPlan::Filter: The PyExpr, predicate, that represents the filtering condition\n    #[pyo3(name = \"getCondition\")]\n    pub fn get_condition(&mut self) -> PyResult<PyExpr> {\n        Ok(PyExpr::from(\n            self.filter.predicate.clone(),\n            Some(vec![self.filter.input.clone()]),\n        ))\n    }\n}\n\nimpl TryFrom<LogicalPlan> for PyFilter {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Filter(filter) => Ok(PyFilter { filter }),\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/join.rs",
    "content": "use datafusion_python::{\n    datafusion_common::Column,\n    datafusion_expr::{\n        and,\n        logical_plan::{Join, JoinType, LogicalPlan},\n        BinaryExpr,\n        Expr,\n        Operator,\n    },\n};\nuse pyo3::prelude::*;\n\nuse crate::{\n    expression::PyExpr,\n    sql::{column, exceptions::py_type_err},\n};\n\n#[pyclass(name = \"Join\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PyJoin {\n    join: Join,\n}\n\n#[pymethods]\nimpl PyJoin {\n    #[pyo3(name = \"getCondition\")]\n    pub fn join_condition(&self) -> PyResult<Option<PyExpr>> {\n        // equi-join filters\n        let mut filters: Vec<Expr> = self\n            .join\n            .on\n            .iter()\n            .map(|(l, r)| match (l, r) {\n                (Expr::Column(l), Expr::Column(r)) => {\n                    Ok(Expr::Column(l.clone()).eq(Expr::Column(r.clone())))\n                }\n                (Expr::Column(l), Expr::Cast(cast)) => {\n                    let right = Column::from_qualified_name(cast.expr.to_string());\n                    Ok(Expr::Column(l.clone()).eq(Expr::Column(right)))\n                }\n                (Expr::Column(l), Expr::BinaryExpr(bin_expr)) => {\n                    Ok(Expr::BinaryExpr(BinaryExpr::new(\n                        Box::new(Expr::Column(l.clone())),\n                        Operator::Eq,\n                        Box::new(Expr::BinaryExpr(bin_expr.clone())),\n                    )))\n                }\n                _ => Err(py_type_err(format!(\n                    \"unsupported join condition. 
Left: {l} - Right: {r}\"\n                ))),\n            })\n            .collect::<Result<Vec<_>, _>>()?;\n\n        // other filter conditions\n        if let Some(filter) = &self.join.filter {\n            filters.push(filter.clone());\n        }\n\n        if !filters.is_empty() {\n            let root_expr = filters[1..]\n                .iter()\n                .fold(filters[0].clone(), |acc, expr| and(acc, expr.clone()));\n\n            Ok(Some(PyExpr::from(\n                root_expr,\n                Some(vec![self.join.left.clone(), self.join.right.clone()]),\n            )))\n        } else {\n            Ok(None)\n        }\n    }\n\n    #[pyo3(name = \"getJoinConditions\")]\n    pub fn join_conditions(&mut self) -> PyResult<Vec<(column::PyColumn, column::PyColumn)>> {\n        // let lhs_table_name = match &*self.join.left {\n        //     LogicalPlan::TableScan(scan) => scan.table_name.clone(),\n        //     _ => {\n        //         return Err(py_type_err(\n        //             \"lhs Expected TableScan but something else was received!\",\n        //         ))\n        //     }\n        // };\n\n        // let rhs_table_name = match &*self.join.right {\n        //     LogicalPlan::TableScan(scan) => scan.table_name.clone(),\n        //     _ => {\n        //         return Err(py_type_err(\n        //             \"rhs Expected TableScan but something else was received!\",\n        //         ))\n        //     }\n        // };\n\n        let mut join_conditions: Vec<(column::PyColumn, column::PyColumn)> = Vec::new();\n        for (lhs, rhs) in self.join.on.clone() {\n            match (lhs, rhs) {\n                (Expr::Column(lhs), Expr::Column(rhs)) => {\n                    join_conditions.push((lhs.into(), rhs.into()));\n                }\n                _ => return Err(py_type_err(\"unsupported join condition\")),\n            }\n        }\n        Ok(join_conditions)\n    }\n\n    /// Returns the type of join represented by this 
LogicalPlan::Join instance\n    #[pyo3(name = \"getJoinType\")]\n    pub fn join_type(&mut self) -> PyResult<String> {\n        match self.join.join_type {\n            JoinType::Inner => Ok(\"INNER\".to_string()),\n            JoinType::Left => Ok(\"LEFT\".to_string()),\n            JoinType::Right => Ok(\"RIGHT\".to_string()),\n            JoinType::Full => Ok(\"FULL\".to_string()),\n            JoinType::LeftSemi => Ok(\"LEFTSEMI\".to_string()),\n            JoinType::LeftAnti => Ok(\"LEFTANTI\".to_string()),\n            JoinType::RightSemi => Ok(\"RIGHTSEMI\".to_string()),\n            JoinType::RightAnti => Ok(\"RIGHTANTI\".to_string()),\n        }\n    }\n}\n\nimpl TryFrom<LogicalPlan> for PyJoin {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Join(join) => Ok(PyJoin { join }),\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/limit.rs",
    "content": "use datafusion_python::{\n    datafusion_common::ScalarValue,\n    datafusion_expr::{logical_plan::Limit, Expr, LogicalPlan},\n};\nuse pyo3::prelude::*;\n\nuse crate::{expression::PyExpr, sql::exceptions::py_type_err};\n\n#[pyclass(name = \"Limit\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PyLimit {\n    limit: Limit,\n}\n\n#[pymethods]\nimpl PyLimit {\n    /// `OFFSET` specified in the query\n    #[pyo3(name = \"getSkip\")]\n    pub fn skip(&self) -> PyResult<PyExpr> {\n        Ok(PyExpr::from(\n            Expr::Literal(ScalarValue::UInt64(Some(self.limit.skip as u64))),\n            Some(vec![self.limit.input.clone()]),\n        ))\n    }\n\n    /// `LIMIT` specified in the query\n    #[pyo3(name = \"getFetch\")]\n    pub fn fetch(&self) -> PyResult<PyExpr> {\n        Ok(PyExpr::from(\n            Expr::Literal(ScalarValue::UInt64(Some(\n                self.limit.fetch.unwrap_or(0) as u64\n            ))),\n            Some(vec![self.limit.input.clone()]),\n        ))\n    }\n}\n\nimpl TryFrom<LogicalPlan> for PyLimit {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Limit(limit) => Ok(PyLimit { limit }),\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/predict_model.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::DFSchemaRef,\n    datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan},\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse super::PyLogicalPlan;\nuse crate::sql::{exceptions::py_type_err, logical};\n\n#[derive(Clone, PartialEq)]\npub struct PredictModelPlanNode {\n    pub schema_name: Option<String>, // \"something\" in `something.model_name`\n    pub model_name: String,\n    pub input: LogicalPlan,\n}\n\nimpl Debug for PredictModelPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for PredictModelPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema_name.hash(state);\n        self.model_name.hash(state);\n        self.input.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for PredictModelPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![&self.input]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        self.input.schema()\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // PREDICT TABLE\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(f, \"PredictModel: model_name={}\", self.model_name)\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        Arc::new(PredictModelPlanNode {\n            schema_name: self.schema_name.clone(),\n            model_name: self.model_name.clone(),\n            input: inputs[0].clone(),\n        })\n    }\n\n    fn name(&self) -> &str {\n        
\"PredictModel\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"PredictModel\", module = \"dask_sql\", subclass)]\npub struct PyPredictModel {\n    pub(crate) predict_model: PredictModelPlanNode,\n}\n\n#[pymethods]\nimpl PyPredictModel {\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<Option<String>> {\n        Ok(self.predict_model.schema_name.clone())\n    }\n\n    #[pyo3(name = \"getModelName\")]\n    fn get_model_name(&self) -> PyResult<String> {\n        Ok(self.predict_model.model_name.clone())\n    }\n\n    #[pyo3(name = \"getSelect\")]\n    fn get_select(&self) -> PyResult<PyLogicalPlan> {\n        Ok(PyLogicalPlan::from(self.predict_model.input.clone()))\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyPredictModel {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            logical::LogicalPlan::Extension(extension) => {\n                if let Some(ext) = extension\n                    .node\n                    .as_any()\n                    .downcast_ref::<PredictModelPlanNode>()\n                {\n                    Ok(PyPredictModel {\n                        predict_model: ext.clone(),\n                    })\n                } else {\n                    Err(py_type_err(\"unexpected plan\"))\n                }\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/projection.rs",
    "content": "use datafusion_python::datafusion_expr::{\n    expr::Alias,\n    logical_plan::Projection,\n    Expr,\n    LogicalPlan,\n};\nuse pyo3::prelude::*;\n\nuse crate::{expression::PyExpr, sql::exceptions::py_type_err};\n\n#[pyclass(name = \"Projection\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PyProjection {\n    pub(crate) projection: Projection,\n}\n\nimpl PyProjection {\n    /// Projection: Gets the names of the fields that should be projected\n    fn projected_expressions(&mut self, local_expr: &PyExpr) -> Vec<PyExpr> {\n        let mut projs: Vec<PyExpr> = Vec::new();\n        match &local_expr.expr {\n            Expr::Alias(Alias { expr, .. }) => {\n                let py_expr: PyExpr =\n                    PyExpr::from(*expr.clone(), Some(vec![self.projection.input.clone()]));\n                projs.extend_from_slice(self.projected_expressions(&py_expr).as_slice());\n            }\n            _ => projs.push(local_expr.clone()),\n        }\n        projs\n    }\n}\n\n#[pymethods]\nimpl PyProjection {\n    #[pyo3(name = \"getNamedProjects\")]\n    fn named_projects(&mut self) -> PyResult<Vec<(String, PyExpr)>> {\n        let mut named: Vec<(String, PyExpr)> = Vec::new();\n        for expression in self.projection.expr.clone() {\n            let py_expr: PyExpr =\n                PyExpr::from(expression, Some(vec![self.projection.input.clone()]));\n            for expr in self.projected_expressions(&py_expr) {\n                match expr.expr {\n                    Expr::Alias(Alias { expr, name }) => named.push((\n                        name.to_string(),\n                        PyExpr::from(*expr, Some(vec![self.projection.input.clone()])),\n                    )),\n                    _ => {\n                        if let Ok(name) = expr._column_name(&self.projection.input) {\n                            named.push((name, expr.clone()));\n                        }\n                    }\n                }\n            
}\n        }\n        Ok(named)\n    }\n}\n\nimpl TryFrom<LogicalPlan> for PyProjection {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Projection(projection) => Ok(PyProjection { projection }),\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/repartition_by.rs",
    "content": "use datafusion_python::datafusion_expr::{\n    logical_plan::{Partitioning, Repartition},\n    Expr,\n    LogicalPlan,\n};\nuse pyo3::prelude::*;\n\nuse crate::{\n    expression::PyExpr,\n    sql::{exceptions::py_type_err, logical},\n};\n\n#[pyclass(name = \"RepartitionBy\", module = \"dask_sql\", subclass)]\npub struct PyRepartitionBy {\n    pub(crate) repartition: Repartition,\n}\n\n#[pymethods]\nimpl PyRepartitionBy {\n    #[pyo3(name = \"getSelectQuery\")]\n    fn get_select_query(&self) -> PyResult<logical::PyLogicalPlan> {\n        let log_plan = &*(self.repartition.input).clone();\n        Ok(log_plan.clone().into())\n    }\n\n    #[pyo3(name = \"getDistributeList\")]\n    fn get_distribute_list(&self) -> PyResult<Vec<PyExpr>> {\n        match &self.repartition.partitioning_scheme {\n            Partitioning::DistributeBy(distribute_list) => Ok(distribute_list\n                .iter()\n                .map(|e| PyExpr::from(e.clone(), Some(vec![self.repartition.input.clone()])))\n                .collect()),\n            _ => Err(py_type_err(\"unexpected repartition strategy\")),\n        }\n    }\n\n    #[pyo3(name = \"getDistributionColumns\")]\n    fn get_distribute_columns(&self) -> PyResult<String> {\n        match &self.repartition.partitioning_scheme {\n            Partitioning::DistributeBy(distribute_list) => Ok(distribute_list\n                .iter()\n                .map(|e| match &e {\n                    Expr::Column(column) => column.name.clone(),\n                    _ => panic!(\"Encountered a type other than Expr::Column\"),\n                })\n                .collect()),\n            _ => Err(py_type_err(\"unexpected repartition strategy\")),\n        }\n    }\n}\n\nimpl TryFrom<LogicalPlan> for PyRepartitionBy {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Repartition(repartition) => Ok(PyRepartitionBy { 
repartition }),\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/show_columns.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{\n        logical_plan::{Extension, UserDefinedLogicalNode},\n        Expr,\n        LogicalPlan,\n    },\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::sql::{exceptions::py_type_err, logical};\n\n#[derive(Clone, PartialEq)]\npub struct ShowColumnsPlanNode {\n    pub schema: DFSchemaRef,\n    pub table_name: String,\n    pub schema_name: Option<String>,\n}\n\nimpl Debug for ShowColumnsPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for ShowColumnsPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema.hash(state);\n        self.table_name.hash(state);\n        self.schema_name.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for ShowColumnsPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // SHOW COLUMNS FROM {table_name}\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(f, \"Show Columns: table_name: {:?}\", self.table_name)\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        _inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        Arc::new(ShowColumnsPlanNode {\n            schema: Arc::new(DFSchema::empty()),\n            table_name: self.table_name.clone(),\n            schema_name: self.schema_name.clone(),\n        })\n    }\n\n    fn name(&self) -> &str {\n        
\"ShowColumns\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"ShowColumns\", module = \"dask_sql\", subclass)]\npub struct PyShowColumns {\n    pub(crate) show_columns: ShowColumnsPlanNode,\n}\n\n#[pymethods]\nimpl PyShowColumns {\n    #[pyo3(name = \"getTableName\")]\n    fn get_table_name(&self) -> PyResult<String> {\n        Ok(self.show_columns.table_name.clone())\n    }\n\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<Option<String>> {\n        Ok(self.show_columns.schema_name.clone())\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyShowColumns {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Extension(Extension { node })\n                if node\n                    .as_any()\n                    .downcast_ref::<ShowColumnsPlanNode>()\n                    .is_some() =>\n            {\n                let ext = node\n                    .as_any()\n                    .downcast_ref::<ShowColumnsPlanNode>()\n                    .expect(\"ShowColumnsPlanNode\");\n                Ok(PyShowColumns {\n                    show_columns: ext.clone(),\n                })\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/show_models.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan},\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::sql::logical::py_type_err;\n\n#[derive(Clone, PartialEq)]\npub struct ShowModelsPlanNode {\n    pub schema: DFSchemaRef,\n    pub schema_name: Option<String>,\n}\n\nimpl Debug for ShowModelsPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for ShowModelsPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema.hash(state);\n        self.schema_name.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for ShowModelsPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // SHOW MODELS\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(f, \"ShowModels\")\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        _inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        Arc::new(ShowModelsPlanNode {\n            schema: Arc::new(DFSchema::empty()),\n            schema_name: self.schema_name.clone(),\n        })\n    }\n\n    fn name(&self) -> &str {\n        \"ShowModels\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match 
other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"ShowModels\", module = \"dask_sql\", subclass)]\npub struct PyShowModels {\n    pub(crate) show_models: ShowModelsPlanNode,\n}\n\n#[pymethods]\nimpl PyShowModels {\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<Option<String>> {\n        Ok(self.show_models.schema_name.clone())\n    }\n}\n\nimpl TryFrom<LogicalPlan> for PyShowModels {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Extension(extension) => {\n                if let Some(ext) = extension.node.as_any().downcast_ref::<ShowModelsPlanNode>() {\n                    Ok(PyShowModels {\n                        show_models: ext.clone(),\n                    })\n                } else {\n                    Err(py_type_err(\"unexpected plan\"))\n                }\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/show_schemas.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{\n        logical_plan::{Extension, UserDefinedLogicalNode},\n        Expr,\n        LogicalPlan,\n    },\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::sql::{exceptions::py_type_err, logical};\n\n#[derive(Clone, PartialEq)]\npub struct ShowSchemasPlanNode {\n    pub schema: DFSchemaRef,\n    pub catalog_name: Option<String>,\n    pub like: Option<String>,\n}\n\nimpl Debug for ShowSchemasPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for ShowSchemasPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema.hash(state);\n        self.like.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for ShowSchemasPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // SHOW SCHEMAS\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(f, \"ShowSchema: catalog_name: {:?}\", self.catalog_name)\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        _inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        Arc::new(ShowSchemasPlanNode {\n            schema: Arc::new(DFSchema::empty()),\n            catalog_name: self.catalog_name.clone(),\n            like: self.like.clone(),\n        })\n    }\n\n    fn name(&self) -> &str {\n        \"ShowSchema\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) 
{\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"ShowSchema\", module = \"dask_sql\", subclass)]\npub struct PyShowSchema {\n    pub(crate) show_schema: ShowSchemasPlanNode,\n}\n\n#[pymethods]\nimpl PyShowSchema {\n    #[pyo3(name = \"getCatalogName\")]\n    fn get_from(&self) -> PyResult<Option<String>> {\n        Ok(self.show_schema.catalog_name.clone())\n    }\n\n    #[pyo3(name = \"getLike\")]\n    fn get_like(&self) -> PyResult<Option<String>> {\n        Ok(self.show_schema.like.clone())\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyShowSchema {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Extension(Extension { node })\n                if node\n                    .as_any()\n                    .downcast_ref::<ShowSchemasPlanNode>()\n                    .is_some() =>\n            {\n                let ext = node\n                    .as_any()\n                    .downcast_ref::<ShowSchemasPlanNode>()\n                    .expect(\"ShowSchemasPlanNode\");\n                Ok(PyShowSchema {\n                    show_schema: ext.clone(),\n                })\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/show_tables.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{\n        logical_plan::{Extension, UserDefinedLogicalNode},\n        Expr,\n        LogicalPlan,\n    },\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::sql::{exceptions::py_type_err, logical};\n\n#[derive(Clone, PartialEq)]\npub struct ShowTablesPlanNode {\n    pub schema: DFSchemaRef,\n    pub catalog_name: Option<String>,\n    pub schema_name: Option<String>,\n}\n\nimpl Debug for ShowTablesPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for ShowTablesPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema.hash(state);\n        self.schema_name.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for ShowTablesPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // SHOW TABLES FROM {schema_name}\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(\n            f,\n            \"ShowTables: catalog_name: {:?}, schema_name: {:?}\",\n            self.catalog_name, self.schema_name\n        )\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        _inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        Arc::new(ShowTablesPlanNode {\n            schema: Arc::new(DFSchema::empty()),\n            catalog_name: self.catalog_name.clone(),\n            schema_name: self.schema_name.clone(),\n    
    })\n    }\n\n    fn name(&self) -> &str {\n        \"ShowTables\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n        match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"ShowTables\", module = \"dask_sql\", subclass)]\npub struct PyShowTables {\n    pub(crate) show_tables: ShowTablesPlanNode,\n}\n\n#[pymethods]\nimpl PyShowTables {\n    #[pyo3(name = \"getCatalogName\")]\n    fn get_catalog_name(&self) -> PyResult<Option<String>> {\n        Ok(self.show_tables.catalog_name.clone())\n    }\n\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<Option<String>> {\n        Ok(self.show_tables.schema_name.clone())\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyShowTables {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Extension(Extension { node })\n                if node.as_any().downcast_ref::<ShowTablesPlanNode>().is_some() =>\n            {\n                let ext = node\n                    .as_any()\n                    .downcast_ref::<ShowTablesPlanNode>()\n                    .expect(\"ShowTablesPlanNode\");\n                Ok(PyShowTables {\n                    show_tables: ext.clone(),\n                })\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/sort.rs",
    "content": "use datafusion_python::datafusion_expr::{logical_plan::Sort, LogicalPlan};\nuse pyo3::prelude::*;\n\nuse crate::{\n    expression::{py_expr_list, PyExpr},\n    sql::exceptions::py_type_err,\n};\n\n#[pyclass(name = \"Sort\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PySort {\n    sort: Sort,\n}\n\n#[pymethods]\nimpl PySort {\n    /// Returns a Vec of the sort expressions\n    #[pyo3(name = \"getCollation\")]\n    pub fn sort_expressions(&self) -> PyResult<Vec<PyExpr>> {\n        py_expr_list(&self.sort.input, &self.sort.expr)\n    }\n\n    /// Returns the optional fetch limit (maximum number of rows), if one is set\n    #[pyo3(name = \"getNumRows\")]\n    pub fn get_fetch_val(&self) -> PyResult<Option<usize>> {\n        Ok(self.sort.fetch)\n    }\n}\n\nimpl TryFrom<LogicalPlan> for PySort {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Sort(sort) => Ok(PySort { sort }),\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/subquery_alias.rs",
    "content": "use datafusion_python::datafusion_expr::{logical_plan::SubqueryAlias, LogicalPlan};\nuse pyo3::prelude::*;\n\nuse crate::sql::exceptions::py_type_err;\n\n#[pyclass(name = \"SubqueryAlias\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PySubqueryAlias {\n    subquery_alias: SubqueryAlias,\n}\n\n#[pymethods]\nimpl PySubqueryAlias {\n    /// Returns the alias assigned to the subquery\n    #[pyo3(name = \"getAlias\")]\n    pub fn alias(&self) -> PyResult<String> {\n        Ok(self.subquery_alias.alias.to_string())\n    }\n}\n\nimpl TryFrom<LogicalPlan> for PySubqueryAlias {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::SubqueryAlias(subquery_alias) => Ok(PySubqueryAlias { subquery_alias }),\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/table_scan.rs",
    "content": "use std::{sync::Arc, vec};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, ScalarValue},\n    datafusion_expr::{\n        expr::{Alias, InList},\n        logical_plan::TableScan,\n        Expr,\n        LogicalPlan,\n    },\n};\nuse pyo3::prelude::*;\n\nuse crate::{\n    error::DaskPlannerError,\n    expression::{py_expr_list, PyExpr},\n    sql::exceptions::py_type_err,\n};\n\n#[pyclass(name = \"TableScan\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PyTableScan {\n    pub(crate) table_scan: TableScan,\n    input: Arc<LogicalPlan>,\n}\n\ntype FilterTuple = (String, String, Option<Vec<PyObject>>);\n#[pyclass(name = \"FilteredResult\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone)]\npub struct PyFilteredResult {\n    // Certain Expr(s) do not have supporting logic in pyarrow for IO filtering\n    // at read time. Those Expr(s) cannot be ignored however. This field stores\n    // those Expr(s) so that they can be used on the Python side to create\n    // Dask operations that handle that filtering as an extra task in the graph.\n    #[pyo3(get)]\n    pub io_unfilterable_exprs: Vec<PyExpr>,\n    // Expr(s) that can have their filtering logic performed in the pyarrow IO logic\n    // are stored here in a DNF format that is expected by pyarrow.\n    #[pyo3(get)]\n    pub filtered_exprs: Vec<(PyExpr, FilterTuple)>,\n}\n\nimpl PyTableScan {\n    /// Ensures that a valid Expr variant type is present\n    fn _valid_expr_type(expr: &[Expr]) -> bool {\n        expr.iter()\n            .all(|f| matches!(f, Expr::Column(_) | Expr::Literal(_)))\n    }\n\n    /// Transform the singular Expr instance into its DNF form serialized in a Vec instance. 
Possibly recursively expanding\n    /// it as well if needed.\n    pub fn _expand_dnf_filter(\n        filter: &Expr,\n        input: &Arc<LogicalPlan>,\n        py: Python,\n    ) -> Result<Vec<(PyExpr, FilterTuple)>, DaskPlannerError> {\n        let mut filter_tuple: Vec<(PyExpr, FilterTuple)> = Vec::new();\n\n        match filter {\n            Expr::InList(InList {\n                expr,\n                list,\n                negated,\n            }) => {\n                // Only handle simple Expr(s) for InList operations for now\n                if PyTableScan::_valid_expr_type(list) {\n                    // While ANSI SQL would not allow for anything other than a Column or Literal\n                    // value in this \"identifying\" `expr` we explicitly check that here just to be sure.\n                    // IF it is something else it is returned to Dask to handle\n                    let ident = match *expr.clone() {\n                        Expr::Column(col) => Ok(col.name),\n                        Expr::Alias(Alias { name, .. }) => Ok(name),\n                        Expr::Literal(val) => Ok(format!(\"{}\", val)),\n                        _ => Err(DaskPlannerError::InvalidIOFilter(format!(\n                            \"Invalid InList Expr type `{}`. 
using in Dask instead\",\n                            filter\n                        ))),\n                    };\n\n                    let op = if *negated { \"not in\" } else { \"in\" };\n                    let il: Result<Vec<PyObject>, DaskPlannerError> = list\n                        .iter()\n                        .map(|f| match f {\n                            Expr::Column(col) => Ok(col.name.clone().into_py(py)),\n                            Expr::Alias(Alias { name, ..}) => Ok(name.clone().into_py(py)),\n                            Expr::Literal(val) => match val {\n                                ScalarValue::Boolean(val) => Ok(val.unwrap().into_py(py)),\n                                ScalarValue::Float32(val) => Ok(val.unwrap().into_py(py)),\n                                ScalarValue::Float64(val) => Ok(val.unwrap().into_py(py)),\n                                ScalarValue::Int8(val) => Ok(val.unwrap().into_py(py)),\n                                ScalarValue::Int16(val) => Ok(val.unwrap().into_py(py)),\n                                ScalarValue::Int32(val) => Ok(val.unwrap().into_py(py)),\n                                ScalarValue::Int64(val) => Ok(val.unwrap().into_py(py)),\n                                ScalarValue::UInt8(val) => Ok(val.unwrap().into_py(py)),\n                                ScalarValue::UInt16(val) => Ok(val.unwrap().into_py(py)),\n                                ScalarValue::UInt32(val) => Ok(val.unwrap().into_py(py)),\n                                ScalarValue::UInt64(val) => Ok(val.unwrap().into_py(py)),\n                                ScalarValue::Utf8(val) => Ok(val.clone().unwrap().into_py(py)),\n                                ScalarValue::LargeUtf8(val) => Ok(val.clone().unwrap().into_py(py)),\n                                _ => Err(DaskPlannerError::InvalidIOFilter(format!(\n                                    \"Unsupported ScalarValue `{}` encountered. 
using in Dask instead\",\n                                    filter\n                                ))),\n                            },\n                            _ => Ok(f.canonical_name().into_py(py)),\n                        })\n                        .collect();\n\n                    filter_tuple.push((\n                        PyExpr::from(filter.clone(), Some(vec![input.clone()])),\n                        (\n                            ident.unwrap_or(expr.canonical_name()),\n                            op.to_string(),\n                            Some(il?),\n                        ),\n                    ));\n                    Ok(filter_tuple)\n                } else {\n                    let er = DaskPlannerError::InvalidIOFilter(format!(\n                        \"Invalid identifying column Expr instance `{}`. using in Dask instead\",\n                        filter\n                    ));\n                    Err::<Vec<(PyExpr, FilterTuple)>, DaskPlannerError>(er)\n                }\n            }\n            Expr::IsNotNull(expr) => {\n                // Only handle simple Expr(s) for IsNotNull operations for now\n                let ident = match *expr.clone() {\n                    Expr::Column(col) => Ok(col.name),\n                    _ => Err(DaskPlannerError::InvalidIOFilter(format!(\n                        \"Invalid IsNotNull Expr type `{}`. 
using in Dask instead\",\n                        filter\n                    ))),\n                };\n\n                filter_tuple.push((\n                    PyExpr::from(filter.clone(), Some(vec![input.clone()])),\n                    (\n                        ident.unwrap_or(expr.canonical_name()),\n                        \"is not\".to_string(),\n                        None,\n                    ),\n                ));\n                Ok(filter_tuple)\n            }\n            _ => {\n                let er = DaskPlannerError::InvalidIOFilter(format!(\n                    \"Unable to apply filter: `{}` to IO reader, using in Dask instead\",\n                    filter\n                ));\n                Err::<Vec<(PyExpr, FilterTuple)>, DaskPlannerError>(er)\n            }\n        }\n    }\n\n    /// Consume the `TableScan` filters (Expr(s)) and convert them into a PyArrow understandable\n    /// DNF format that can be directly passed to PyArrow IO readers for Predicate Pushdown. 
Expr(s)\n    /// that cannot be converted to correlating PyArrow IO calls will be returned as is and can be\n    /// used in the Python logic to form Dask tasks for the graph to do computational filtering.\n    pub fn _expand_dnf_filters(\n        input: &Arc<LogicalPlan>,\n        filters: &[Expr],\n        py: Python,\n    ) -> PyFilteredResult {\n        let mut filtered_exprs: Vec<(PyExpr, FilterTuple)> = Vec::new();\n        let mut unfiltered_exprs: Vec<PyExpr> = Vec::new();\n\n        filters\n            .iter()\n            .for_each(|f| match PyTableScan::_expand_dnf_filter(f, input, py) {\n                Ok(mut expanded_dnf_filter) => filtered_exprs.append(&mut expanded_dnf_filter),\n                Err(_e) => {\n                    unfiltered_exprs.push(PyExpr::from(f.clone(), Some(vec![input.clone()])))\n                }\n            });\n\n        PyFilteredResult {\n            io_unfilterable_exprs: unfiltered_exprs,\n            filtered_exprs,\n        }\n    }\n}\n\n#[pymethods]\nimpl PyTableScan {\n    #[pyo3(name = \"getTableScanProjects\")]\n    fn scan_projects(&mut self) -> PyResult<Vec<String>> {\n        match &self.table_scan.projection {\n            Some(indices) => {\n                let schema = self.table_scan.source.schema();\n                Ok(indices\n                    .iter()\n                    .map(|i| schema.field(*i).name().to_string())\n                    .collect())\n            }\n            None => Ok(vec![]),\n        }\n    }\n\n    /// If the 'TableScan' contains columns that should be projected during the\n    /// read return True, otherwise return False\n    #[pyo3(name = \"containsProjections\")]\n    fn contains_projections(&self) -> bool {\n        self.table_scan.projection.is_some()\n    }\n\n    #[pyo3(name = \"getFilters\")]\n    fn scan_filters(&self) -> PyResult<Vec<PyExpr>> {\n        py_expr_list(&self.input, &self.table_scan.filters)\n    }\n\n    #[pyo3(name = \"getDNFFilters\")]\n    fn 
dnf_io_filters(&self, py: Python) -> PyResult<PyFilteredResult> {\n        let results = PyTableScan::_expand_dnf_filters(&self.input, &self.table_scan.filters, py);\n        Ok(results)\n    }\n}\n\nimpl TryFrom<LogicalPlan> for PyTableScan {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::TableScan(table_scan) => {\n                // Create an input logical plan that's identical to the table scan with schema from the table source\n                let mut input = table_scan.clone();\n                input.projected_schema = DFSchema::try_from_qualified_schema(\n                    &table_scan.table_name,\n                    &table_scan.source.schema(),\n                )\n                .map_or(input.projected_schema, Arc::new);\n\n                Ok(PyTableScan {\n                    table_scan,\n                    input: Arc::new(LogicalPlan::TableScan(input)),\n                })\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/use_schema.rs",
    "content": "use std::{\n    any::Any,\n    fmt,\n    hash::{Hash, Hasher},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{DFSchema, DFSchemaRef},\n    datafusion_expr::{logical_plan::UserDefinedLogicalNode, Expr, LogicalPlan},\n};\nuse fmt::Debug;\nuse pyo3::prelude::*;\n\nuse crate::sql::{exceptions::py_type_err, logical};\n\n#[derive(Clone, PartialEq)]\npub struct UseSchemaPlanNode {\n    pub schema: DFSchemaRef,\n    pub schema_name: String,\n}\n\nimpl Debug for UseSchemaPlanNode {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        self.fmt_for_explain(f)\n    }\n}\n\nimpl Hash for UseSchemaPlanNode {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        self.schema.hash(state);\n        self.schema_name.hash(state);\n    }\n}\n\nimpl UserDefinedLogicalNode for UseSchemaPlanNode {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn inputs(&self) -> Vec<&LogicalPlan> {\n        vec![]\n    }\n\n    fn schema(&self) -> &DFSchemaRef {\n        &self.schema\n    }\n\n    fn expressions(&self) -> Vec<Expr> {\n        // there is no need to expose any expressions here since DataFusion would\n        // not be able to do anything with expressions that are specific to\n        // USE SCHEMA\n        vec![]\n    }\n\n    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        write!(f, \"UseSchema: schema_name={}\", self.schema_name)\n    }\n\n    fn from_template(\n        &self,\n        _exprs: &[Expr],\n        _inputs: &[LogicalPlan],\n    ) -> Arc<dyn UserDefinedLogicalNode> {\n        Arc::new(UseSchemaPlanNode {\n            schema: Arc::new(DFSchema::empty()),\n            schema_name: self.schema_name.clone(),\n        })\n    }\n\n    fn name(&self) -> &str {\n        \"UseSchema\"\n    }\n\n    fn dyn_hash(&self, state: &mut dyn Hasher) {\n        let mut s = state;\n        self.hash(&mut s);\n    }\n\n    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {\n     
   match other.as_any().downcast_ref::<Self>() {\n            Some(o) => self == o,\n            None => false,\n        }\n    }\n}\n\n#[pyclass(name = \"UseSchema\", module = \"dask_sql\", subclass)]\npub struct PyUseSchema {\n    pub(crate) use_schema: UseSchemaPlanNode,\n}\n\n#[pymethods]\nimpl PyUseSchema {\n    #[pyo3(name = \"getSchemaName\")]\n    fn get_schema_name(&self) -> PyResult<String> {\n        Ok(self.use_schema.schema_name.clone())\n    }\n}\n\nimpl TryFrom<logical::LogicalPlan> for PyUseSchema {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            logical::LogicalPlan::Extension(extension) => {\n                if let Some(ext) = extension.node.as_any().downcast_ref::<UseSchemaPlanNode>() {\n                    Ok(PyUseSchema {\n                        use_schema: ext.clone(),\n                    })\n                } else {\n                    Err(py_type_err(\"unexpected plan\"))\n                }\n            }\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical/window.rs",
    "content": "use datafusion_python::{\n    datafusion_common::ScalarValue,\n    datafusion_expr::{\n        expr::WindowFunction,\n        logical_plan::Window,\n        Expr,\n        LogicalPlan,\n        WindowFrame,\n        WindowFrameBound,\n    },\n};\nuse pyo3::prelude::*;\n\nuse crate::{\n    error::DaskPlannerError,\n    expression::{py_expr_list, PyExpr},\n    sql::exceptions::py_type_err,\n};\n\n#[pyclass(name = \"Window\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PyWindow {\n    window: Window,\n}\n\n#[pyclass(name = \"WindowFrame\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PyWindowFrame {\n    window_frame: WindowFrame,\n}\n\n#[pyclass(name = \"WindowFrameBound\", module = \"dask_sql\", subclass)]\n#[derive(Clone)]\npub struct PyWindowFrameBound {\n    frame_bound: WindowFrameBound,\n}\n\nimpl TryFrom<LogicalPlan> for PyWindow {\n    type Error = PyErr;\n\n    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {\n        match logical_plan {\n            LogicalPlan::Window(window) => Ok(PyWindow { window }),\n            _ => Err(py_type_err(\"unexpected plan\")),\n        }\n    }\n}\n\nimpl From<WindowFrame> for PyWindowFrame {\n    fn from(window_frame: WindowFrame) -> Self {\n        PyWindowFrame { window_frame }\n    }\n}\n\nimpl From<WindowFrameBound> for PyWindowFrameBound {\n    fn from(frame_bound: WindowFrameBound) -> Self {\n        PyWindowFrameBound { frame_bound }\n    }\n}\n\n#[pymethods]\nimpl PyWindow {\n    /// Returns window expressions\n    #[pyo3(name = \"getGroups\")]\n    pub fn get_window_expr(&self) -> PyResult<Vec<PyExpr>> {\n        py_expr_list(&self.window.input, &self.window.window_expr)\n    }\n\n    /// Returns order by columns in a window function expression\n    #[pyo3(name = \"getSortExprs\")]\n    pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {\n        match expr.expr.unalias() {\n            
Expr::WindowFunction(WindowFunction { order_by, .. }) => {\n                py_expr_list(&self.window.input, &order_by)\n            }\n            other => Err(not_window_function_err(other)),\n        }\n    }\n\n    /// Returns partition by columns in a window function expression\n    #[pyo3(name = \"getPartitionExprs\")]\n    pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {\n        match expr.expr.unalias() {\n            Expr::WindowFunction(WindowFunction { partition_by, .. }) => {\n                py_expr_list(&self.window.input, &partition_by)\n            }\n            other => Err(not_window_function_err(other)),\n        }\n    }\n\n    /// Returns input args for a window function expression\n    #[pyo3(name = \"getArgs\")]\n    pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {\n        match expr.expr.unalias() {\n            Expr::WindowFunction(WindowFunction { args, .. }) => {\n                py_expr_list(&self.window.input, &args)\n            }\n            other => Err(not_window_function_err(other)),\n        }\n    }\n\n    /// Returns the window function name\n    #[pyo3(name = \"getWindowFuncName\")]\n    pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {\n        match expr.expr.unalias() {\n            Expr::WindowFunction(WindowFunction { fun, .. }) => Ok(fun.to_string()),\n            other => Err(not_window_function_err(other)),\n        }\n    }\n\n    /// Returns the PyWindowFrame for a given window function expression, or None if the\n    /// expression is not a window function\n    #[pyo3(name = \"getWindowFrame\")]\n    pub fn get_window_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {\n        match expr.expr.unalias() {\n            Expr::WindowFunction(WindowFunction { window_frame, .. 
}) => Some(window_frame.into()),\n            _ => None,\n        }\n    }\n}\n\nfn not_window_function_err(expr: Expr) -> PyErr {\n    py_type_err(format!(\n        \"Provided {} Expr {:?} is not a WindowFunction type\",\n        expr.variant_name(),\n        expr\n    ))\n}\n\n#[pymethods]\nimpl PyWindowFrame {\n    /// Returns the window frame units for the bounds\n    #[pyo3(name = \"getFrameUnit\")]\n    pub fn get_frame_units(&self) -> PyResult<String> {\n        Ok(self.window_frame.units.to_string())\n    }\n    /// Returns starting bound\n    #[pyo3(name = \"getLowerBound\")]\n    pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {\n        Ok(self.window_frame.start_bound.clone().into())\n    }\n    /// Returns end bound\n    #[pyo3(name = \"getUpperBound\")]\n    pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {\n        Ok(self.window_frame.end_bound.clone().into())\n    }\n}\n\n#[pymethods]\nimpl PyWindowFrameBound {\n    /// Returns if the frame bound is current row\n    #[pyo3(name = \"isCurrentRow\")]\n    pub fn is_current_row(&self) -> bool {\n        matches!(self.frame_bound, WindowFrameBound::CurrentRow)\n    }\n\n    /// Returns if the frame bound is preceding\n    #[pyo3(name = \"isPreceding\")]\n    pub fn is_preceding(&self) -> bool {\n        matches!(self.frame_bound, WindowFrameBound::Preceding(_))\n    }\n\n    /// Returns if the frame bound is following\n    #[pyo3(name = \"isFollowing\")]\n    pub fn is_following(&self) -> bool {\n        matches!(self.frame_bound, WindowFrameBound::Following(_))\n    }\n    /// Returns the offset of the window frame\n    #[pyo3(name = \"getOffset\")]\n    pub fn get_offset(&self) -> PyResult<Option<u64>> {\n        match &self.frame_bound {\n            WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val {\n                x if x.is_null() => Ok(None),\n                ScalarValue::UInt64(v) => Ok(*v),\n                // The cast below is only 
safe because window bounds cannot be negative\n                ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)),\n                ScalarValue::Utf8(v) => {\n                    let s = v.clone().unwrap();\n                    match s.parse::<u64>() {\n                        Ok(s) => Ok(Some(s)),\n                        Err(_e) => Err(DaskPlannerError::Internal(format!(\n                            \"Unable to parse u64 from Utf8 value '{s}'\"\n                        ))\n                        .into()),\n                    }\n                }\n                ref x => Err(DaskPlannerError::Internal(format!(\n                    \"Unexpected window frame bound: {x}\"\n                ))\n                .into()),\n            },\n            WindowFrameBound::CurrentRow => Ok(None),\n        }\n    }\n    /// Returns if the frame bound is unbounded\n    #[pyo3(name = \"isUnbounded\")]\n    pub fn is_unbounded(&self) -> PyResult<bool> {\n        match &self.frame_bound {\n            WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()),\n            WindowFrameBound::CurrentRow => Ok(false),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/logical.rs",
    "content": "use crate::sql::{\n    table,\n    types::{rel_data_type::RelDataType, rel_data_type_field::RelDataTypeField},\n};\n\npub mod aggregate;\npub mod alter_schema;\npub mod alter_table;\npub mod analyze_table;\npub mod create_catalog_schema;\npub mod create_experiment;\npub mod create_memory_table;\npub mod create_model;\npub mod create_table;\npub mod describe_model;\npub mod drop_model;\npub mod drop_schema;\npub mod drop_table;\npub mod empty_relation;\npub mod explain;\npub mod export_model;\npub mod filter;\npub mod join;\npub mod limit;\npub mod predict_model;\npub mod projection;\npub mod repartition_by;\npub mod show_columns;\npub mod show_models;\npub mod show_schemas;\npub mod show_tables;\npub mod sort;\npub mod subquery_alias;\npub mod table_scan;\npub mod use_schema;\npub mod window;\n\nuse datafusion_python::{\n    datafusion_common::{DFSchemaRef, DataFusionError},\n    datafusion_expr::{DdlStatement, LogicalPlan},\n};\nuse pyo3::prelude::*;\n\nuse self::{\n    alter_schema::AlterSchemaPlanNode,\n    alter_table::AlterTablePlanNode,\n    analyze_table::AnalyzeTablePlanNode,\n    create_catalog_schema::CreateCatalogSchemaPlanNode,\n    create_experiment::CreateExperimentPlanNode,\n    create_model::CreateModelPlanNode,\n    create_table::CreateTablePlanNode,\n    describe_model::DescribeModelPlanNode,\n    drop_model::DropModelPlanNode,\n    drop_schema::DropSchemaPlanNode,\n    export_model::ExportModelPlanNode,\n    predict_model::PredictModelPlanNode,\n    show_columns::ShowColumnsPlanNode,\n    show_models::ShowModelsPlanNode,\n    show_schemas::ShowSchemasPlanNode,\n    show_tables::ShowTablesPlanNode,\n    use_schema::UseSchemaPlanNode,\n};\nuse crate::{error::Result, sql::exceptions::py_type_err};\n\n#[pyclass(name = \"LogicalPlan\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone)]\npub struct PyLogicalPlan {\n    /// The original LogicalPlan that was parsed by DataFusion from the input SQL\n    pub(crate) original_plan: 
LogicalPlan,\n    /// The original_plan is traversed. current_node stores the current node of this traversal\n    pub(crate) current_node: Option<LogicalPlan>,\n}\n\n/// Unfortunately PyO3 forces us to do this as placing these methods in the #[pymethods] version\n/// of `impl PyLogicalPlan` causes issues with types not properly being mapped to Python from Rust\nimpl PyLogicalPlan {\n    /// Getter method for the LogicalPlan, if current_node is None return original_plan.\n    pub(crate) fn current_node(&mut self) -> LogicalPlan {\n        match &self.current_node {\n            Some(current) => current.clone(),\n            None => {\n                self.current_node = Some(self.original_plan.clone());\n                self.current_node.clone().unwrap()\n            }\n        }\n    }\n}\n\n/// Convert a LogicalPlan to a Python equivalent type\nfn to_py_plan<T: TryFrom<LogicalPlan, Error = PyErr>>(\n    current_node: Option<&LogicalPlan>,\n) -> PyResult<T> {\n    match current_node {\n        Some(plan) => plan.clone().try_into(),\n        _ => Err(py_type_err(\"current_node was None\")),\n    }\n}\n\n#[pymethods]\nimpl PyLogicalPlan {\n    /// LogicalPlan::Aggregate as PyAggregate\n    pub fn aggregate(&self) -> PyResult<aggregate::PyAggregate> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::EmptyRelation as PyEmptyRelation\n    pub fn empty_relation(&self) -> PyResult<empty_relation::PyEmptyRelation> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Explain as PyExplain\n    pub fn explain(&self) -> PyResult<explain::PyExplain> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Filter as PyFilter\n    pub fn filter(&self) -> PyResult<filter::PyFilter> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Join as PyJoin\n    pub fn join(&self) -> PyResult<join::PyJoin> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// 
LogicalPlan::Limit as PyLimit\n    pub fn limit(&self) -> PyResult<limit::PyLimit> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Projection as PyProjection\n    pub fn projection(&self) -> PyResult<projection::PyProjection> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Sort as PySort\n    pub fn sort(&self) -> PyResult<sort::PySort> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::SubqueryAlias as PySubqueryAlias\n    pub fn subquery_alias(&self) -> PyResult<subquery_alias::PySubqueryAlias> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Window as PyWindow\n    pub fn window(&self) -> PyResult<window::PyWindow> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::TableScan as PyTableScan\n    pub fn table_scan(&self) -> PyResult<table_scan::PyTableScan> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::CreateMemoryTable as PyCreateMemoryTable\n    pub fn create_memory_table(&self) -> PyResult<create_memory_table::PyCreateMemoryTable> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::CreateModel as PyCreateModel\n    pub fn create_model(&self) -> PyResult<create_model::PyCreateModel> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::CreateExperiment as PyCreateExperiment\n    pub fn create_experiment(&self) -> PyResult<create_experiment::PyCreateExperiment> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::DropTable as PyDropTable\n    pub fn drop_table(&self) -> PyResult<drop_table::PyDropTable> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::DropModel as PyDropModel\n    pub fn drop_model(&self) -> PyResult<drop_model::PyDropModel> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::ShowSchemas as PyShowSchema\n   
 pub fn show_schemas(&self) -> PyResult<show_schemas::PyShowSchema> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Repartition as PyRepartitionBy\n    pub fn repartition_by(&self) -> PyResult<repartition_by::PyRepartitionBy> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::ShowTables as PyShowTables\n    pub fn show_tables(&self) -> PyResult<show_tables::PyShowTables> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::CreateTable as PyCreateTable\n    pub fn create_table(&self) -> PyResult<create_table::PyCreateTable> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::PredictModel as PyPredictModel\n    pub fn predict_model(&self) -> PyResult<predict_model::PyPredictModel> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::DescribeModel as PyDescribeModel\n    pub fn describe_model(&self) -> PyResult<describe_model::PyDescribeModel> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::ExportModel as PyExportModel\n    pub fn export_model(&self) -> PyResult<export_model::PyExportModel> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::ShowColumns as PyShowColumns\n    pub fn show_columns(&self) -> PyResult<show_columns::PyShowColumns> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::ShowModels as PyShowModels\n    pub fn show_models(&self) -> PyResult<show_models::PyShowModels> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::AnalyzeTable as PyAnalyzeTable\n    pub fn analyze_table(&self) -> PyResult<analyze_table::PyAnalyzeTable> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::CreateCatalogSchema as PyCreateCatalogSchema\n    pub fn create_catalog_schema(&self) -> PyResult<create_catalog_schema::PyCreateCatalogSchema> {\n        
to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::DropSchema as PyDropSchema\n    pub fn drop_schema(&self) -> PyResult<drop_schema::PyDropSchema> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::UseSchema as PyUseSchema\n    pub fn use_schema(&self) -> PyResult<use_schema::PyUseSchema> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::AlterTable as PyAlterTable\n    pub fn alter_table(&self) -> PyResult<alter_table::PyAlterTable> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// LogicalPlan::Extension::AlterSchema as PyAlterSchema\n    pub fn alter_schema(&self) -> PyResult<alter_schema::PyAlterSchema> {\n        to_py_plan(self.current_node.as_ref())\n    }\n\n    /// Gets the \"input\" for the current LogicalPlan\n    pub fn get_inputs(&mut self) -> PyResult<Vec<PyLogicalPlan>> {\n        let mut py_inputs: Vec<PyLogicalPlan> = Vec::new();\n        for input in self.current_node().inputs() {\n            py_inputs.push(input.clone().into());\n        }\n        Ok(py_inputs)\n    }\n\n    /// If the LogicalPlan represents access to a Table that instance is returned\n    /// otherwise None is returned\n    #[pyo3(name = \"getTable\")]\n    pub fn table(&mut self) -> PyResult<table::DaskTable> {\n        match table::table_from_logical_plan(&self.current_node())? 
{\n            Some(table) => Ok(table),\n            None => Err(py_type_err(\n                \"Unable to compute DaskTable from DataFusion LogicalPlan\",\n            )),\n        }\n    }\n\n    #[pyo3(name = \"getCurrentNodeSchemaName\")]\n    pub fn get_current_node_schema_name(&self) -> PyResult<&str> {\n        match &self.current_node {\n            Some(e) => {\n                let _sch: &DFSchemaRef = e.schema();\n                //TODO: Where can I actually get this in the context of the running query?\n                Ok(\"root\")\n            }\n            None => Err(py_type_err(DataFusionError::Plan(format!(\n                \"Current schema not found. Defaulting to {:?}\",\n                \"root\"\n            )))),\n        }\n    }\n\n    #[pyo3(name = \"getCurrentNodeTableName\")]\n    pub fn get_current_node_table_name(&mut self) -> PyResult<String> {\n        match self.table() {\n            Ok(dask_table) => Ok(dask_table.table_name),\n            Err(_e) => Err(py_type_err(\"Unable to determine current node table name\")),\n        }\n    }\n\n    /// Gets the Relation \"type\" of the current node. 
Ex: Projection, TableScan, etc\n    pub fn get_current_node_type(&mut self) -> PyResult<&str> {\n        Ok(match self.current_node() {\n            LogicalPlan::Dml(_) => \"DataManipulationLanguage\",\n            LogicalPlan::DescribeTable(_) => \"DescribeTable\",\n            LogicalPlan::Prepare(_) => \"Prepare\",\n            LogicalPlan::Distinct(_) => \"Distinct\",\n            LogicalPlan::Projection(_projection) => \"Projection\",\n            LogicalPlan::Filter(_filter) => \"Filter\",\n            LogicalPlan::Window(_window) => \"Window\",\n            LogicalPlan::Aggregate(_aggregate) => \"Aggregate\",\n            LogicalPlan::Sort(_sort) => \"Sort\",\n            LogicalPlan::Join(_join) => \"Join\",\n            LogicalPlan::CrossJoin(_cross_join) => \"CrossJoin\",\n            LogicalPlan::Repartition(_repartition) => \"Repartition\",\n            LogicalPlan::Union(_union) => \"Union\",\n            LogicalPlan::TableScan(_table_scan) => \"TableScan\",\n            LogicalPlan::EmptyRelation(_empty_relation) => \"EmptyRelation\",\n            LogicalPlan::Limit(_limit) => \"Limit\",\n            LogicalPlan::Ddl(DdlStatement::CreateExternalTable { .. }) => \"CreateExternalTable\",\n            LogicalPlan::Ddl(DdlStatement::CreateMemoryTable { .. }) => \"CreateMemoryTable\",\n            LogicalPlan::Ddl(DdlStatement::DropTable { .. }) => \"DropTable\",\n            LogicalPlan::Ddl(DdlStatement::DropView { .. }) => \"DropView\",\n            LogicalPlan::Values(_values) => \"Values\",\n            LogicalPlan::Explain(_explain) => \"Explain\",\n            LogicalPlan::Analyze(_analyze) => \"Analyze\",\n            LogicalPlan::Subquery(_sub_query) => \"Subquery\",\n            LogicalPlan::SubqueryAlias(_sqalias) => \"SubqueryAlias\",\n            LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema { .. }) => \"CreateCatalogSchema\",\n            LogicalPlan::Ddl(DdlStatement::DropCatalogSchema { .. 
}) => \"DropCatalogSchema\",\n            LogicalPlan::Ddl(DdlStatement::CreateCatalog { .. }) => \"CreateCatalog\",\n            LogicalPlan::Ddl(DdlStatement::CreateView { .. }) => \"CreateView\",\n            LogicalPlan::Statement(_) => \"Statement\",\n            // Further examine and return the name that is a possible Dask-SQL Extension type\n            LogicalPlan::Extension(extension) => {\n                let node = extension.node.as_any();\n                if node.downcast_ref::<CreateModelPlanNode>().is_some() {\n                    \"CreateModel\"\n                } else if node.downcast_ref::<CreateExperimentPlanNode>().is_some() {\n                    \"CreateExperiment\"\n                } else if node.downcast_ref::<CreateCatalogSchemaPlanNode>().is_some() {\n                    \"CreateCatalogSchema\"\n                } else if node.downcast_ref::<CreateTablePlanNode>().is_some() {\n                    \"CreateTable\"\n                } else if node.downcast_ref::<DropModelPlanNode>().is_some() {\n                    \"DropModel\"\n                } else if node.downcast_ref::<PredictModelPlanNode>().is_some() {\n                    \"PredictModel\"\n                } else if node.downcast_ref::<ExportModelPlanNode>().is_some() {\n                    \"ExportModel\"\n                } else if node.downcast_ref::<DescribeModelPlanNode>().is_some() {\n                    \"DescribeModel\"\n                } else if node.downcast_ref::<ShowSchemasPlanNode>().is_some() {\n                    \"ShowSchemas\"\n                } else if node.downcast_ref::<ShowTablesPlanNode>().is_some() {\n                    \"ShowTables\"\n                } else if node.downcast_ref::<ShowColumnsPlanNode>().is_some() {\n                    \"ShowColumns\"\n                } else if node.downcast_ref::<ShowModelsPlanNode>().is_some() {\n                    \"ShowModels\"\n                } else if node.downcast_ref::<DropSchemaPlanNode>().is_some() {\n                 
   \"DropSchema\"\n                } else if node.downcast_ref::<UseSchemaPlanNode>().is_some() {\n                    \"UseSchema\"\n                } else if node.downcast_ref::<AnalyzeTablePlanNode>().is_some() {\n                    \"AnalyzeTable\"\n                } else if node.downcast_ref::<AlterTablePlanNode>().is_some() {\n                    \"AlterTable\"\n                } else if node.downcast_ref::<AlterSchemaPlanNode>().is_some() {\n                    \"AlterSchema\"\n                } else {\n                    // Default to generic `Extension`\n                    \"Extension\"\n                }\n            }\n            LogicalPlan::Unnest(_unnest) => \"Unnest\",\n            LogicalPlan::Copy(_) => \"Copy\",\n        })\n    }\n\n    /// Explain plan for the full and original LogicalPlan\n    pub fn explain_original(&self) -> PyResult<String> {\n        Ok(format!(\"{}\", self.original_plan.display_indent()))\n    }\n\n    /// Explain plan from the current node onward\n    pub fn explain_current(&mut self) -> PyResult<String> {\n        Ok(format!(\"{}\", self.current_node().display_indent()))\n    }\n\n    #[pyo3(name = \"getRowType\")]\n    pub fn row_type(&self) -> PyResult<RelDataType> {\n        match &self.original_plan {\n            LogicalPlan::Join(join) => {\n                let mut lhs_fields: Vec<RelDataTypeField> = join\n                    .left\n                    .schema()\n                    .fields()\n                    .iter()\n                    .map(|f| RelDataTypeField::from(f, join.left.schema().as_ref()))\n                    .collect::<Result<Vec<_>>>()\n                    .map_err(py_type_err)?;\n\n                let mut rhs_fields: Vec<RelDataTypeField> = join\n                    .right\n                    .schema()\n                    .fields()\n                    .iter()\n                    .map(|f| RelDataTypeField::from(f, join.right.schema().as_ref()))\n                    
.collect::<Result<Vec<_>>>()\n                    .map_err(py_type_err)?;\n\n                lhs_fields.append(&mut rhs_fields);\n                Ok(RelDataType::new(false, lhs_fields))\n            }\n            LogicalPlan::Distinct(distinct) => {\n                let schema = distinct.input.schema();\n                let rel_fields: Vec<RelDataTypeField> = schema\n                    .fields()\n                    .iter()\n                    .map(|f| RelDataTypeField::from(f, schema.as_ref()))\n                    .collect::<Result<Vec<_>>>()\n                    .map_err(py_type_err)?;\n                Ok(RelDataType::new(false, rel_fields))\n            }\n            _ => {\n                let schema = self.original_plan.schema();\n                let rel_fields: Vec<RelDataTypeField> = schema\n                    .fields()\n                    .iter()\n                    .map(|f| RelDataTypeField::from(f, schema.as_ref()))\n                    .collect::<Result<Vec<_>>>()\n                    .map_err(py_type_err)?;\n\n                Ok(RelDataType::new(false, rel_fields))\n            }\n        }\n    }\n}\n\nimpl From<PyLogicalPlan> for LogicalPlan {\n    fn from(logical_plan: PyLogicalPlan) -> LogicalPlan {\n        logical_plan.original_plan\n    }\n}\n\nimpl From<LogicalPlan> for PyLogicalPlan {\n    fn from(logical_plan: LogicalPlan) -> PyLogicalPlan {\n        PyLogicalPlan {\n            original_plan: logical_plan,\n            current_node: None,\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/optimizer/decorrelate_where_exists.rs",
    "content": "// Licensed to the Apache Software Foundation (ASF) under one\n// or more contributor license agreements.  See the NOTICE file\n// distributed with this work for additional information\n// regarding copyright ownership.  The ASF licenses this file\n// to you under the Apache License, Version 2.0 (the\n// \"License\"); you may not use this file except in compliance\n// with the License.  You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing,\n// software distributed under the License is distributed on an\n// \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n// KIND, either express or implied.  See the License for the\n// specific language governing permissions and limitations\n// under the License.\n\nuse std::sync::Arc;\n\nuse datafusion_python::{\n    datafusion_common::{Column, DataFusionError, Result},\n    datafusion_expr::{\n        expr::Exists,\n        logical_plan::{Distinct, Filter, JoinType, Subquery},\n        Expr,\n        LogicalPlan,\n        LogicalPlanBuilder,\n    },\n    datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule},\n};\n\nuse crate::sql::optimizer::utils::{\n    collect_subquery_cols,\n    conjunction,\n    extract_join_filters,\n    split_conjunction,\n};\n\n/// Optimizer rule for rewriting EXISTS subquery filters to joins\n#[derive(Default)]\npub struct DecorrelateWhereExists {}\n\nimpl DecorrelateWhereExists {\n    #[allow(missing_docs)]\n    pub fn new() -> Self {\n        Self {}\n    }\n\n    /// Finds expressions that contain an EXISTS subquery (and recurses when found)\n    ///\n    /// # Arguments\n    ///\n    /// * `predicate` - A conjunction to split and search\n    /// * `config` - Optimizer configuration, passed through when optimizing nested subqueries\n    ///\n    /// Returns a tuple (subqueries, non-subquery expressions)\n    fn extract_subquery_exprs(\n        &self,\n        predicate: &Expr,\n       
 config: &dyn OptimizerConfig,\n    ) -> Result<(Vec<SubqueryInfo>, Vec<Expr>)> {\n        let filters = split_conjunction(predicate);\n\n        let mut subqueries = vec![];\n        let mut others = vec![];\n        for it in filters.iter() {\n            match it {\n                Expr::Exists(Exists { subquery, negated }) => {\n                    let subquery_plan = self\n                        .try_optimize(&subquery.subquery, config)?\n                        .map(Arc::new)\n                        .unwrap_or_else(|| subquery.subquery.clone());\n                    let new_subquery = subquery.with_plan(subquery_plan);\n                    subqueries.push(SubqueryInfo::new(new_subquery, *negated));\n                }\n                _ => others.push((*it).clone()),\n            }\n        }\n\n        Ok((subqueries, others))\n    }\n}\n\nimpl OptimizerRule for DecorrelateWhereExists {\n    fn try_optimize(\n        &self,\n        plan: &LogicalPlan,\n        config: &dyn OptimizerConfig,\n    ) -> Result<Option<LogicalPlan>> {\n        match plan {\n            LogicalPlan::Filter(filter) => {\n                let (subqueries, other_exprs) =\n                    self.extract_subquery_exprs(&filter.predicate, config)?;\n                if subqueries.is_empty() {\n                    // regular filter, no subquery exists clause here\n                    return Ok(None);\n                }\n\n                // iterate through all exists clauses in predicate, turning each into a join\n                let mut cur_input = filter.input.as_ref().clone();\n                for subquery in subqueries {\n                    if let Some(x) = optimize_exists(&subquery, &cur_input)? 
{\n                        cur_input = x;\n                    } else {\n                        return Ok(None);\n                    }\n                }\n\n                let expr = conjunction(other_exprs);\n                if let Some(expr) = expr {\n                    let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;\n                    cur_input = LogicalPlan::Filter(new_filter);\n                }\n\n                Ok(Some(cur_input))\n            }\n            _ => Ok(None),\n        }\n    }\n\n    fn name(&self) -> &str {\n        \"decorrelate_where_exists\"\n    }\n\n    fn apply_order(&self) -> Option<ApplyOrder> {\n        Some(ApplyOrder::TopDown)\n    }\n}\n\n/// Takes a query like:\n///\n/// SELECT t1.id\n/// FROM t1\n/// WHERE exists\n/// (\n///    SELECT t2.id FROM t2 WHERE t1.id = t2.id\n/// )\n///\n/// and optimizes it into:\n///\n/// SELECT t1.id\n/// FROM t1 LEFT SEMI\n/// JOIN t2\n/// ON t1.id = t2.id\n///\n/// # Arguments\n///\n/// * query_info - The subquery and negated(exists/not exists) info.\n/// * outer_input - The non-subquery portion (relation t1)\nfn optimize_exists(\n    query_info: &SubqueryInfo,\n    outer_input: &LogicalPlan,\n) -> Result<Option<LogicalPlan>> {\n    let subquery = query_info.query.subquery.as_ref();\n    if let Some((join_filter, optimized_subquery)) = optimize_subquery(subquery)? 
{\n        // join our sub query into the main plan\n        let join_type = match query_info.negated {\n            true => JoinType::LeftAnti,\n            false => JoinType::LeftSemi,\n        };\n\n        let new_plan = LogicalPlanBuilder::from(outer_input.clone())\n            .join(\n                optimized_subquery,\n                join_type,\n                (Vec::<Column>::new(), Vec::<Column>::new()),\n                Some(join_filter),\n            )?\n            .build()?;\n\n        Ok(Some(new_plan))\n    } else {\n        Ok(None)\n    }\n}\n/// Optimize the subquery and extract the possible join filter.\n/// This function can't optimize non-correlated subquery, and will return None.\nfn optimize_subquery(subquery: &LogicalPlan) -> Result<Option<(Expr, LogicalPlan)>> {\n    match subquery {\n        LogicalPlan::Distinct(subqry_distinct) => {\n            let distinct_input = &subqry_distinct.input;\n            let optimized_plan = optimize_subquery(distinct_input)?.map(|(filters, right)| {\n                (\n                    filters,\n                    LogicalPlan::Distinct(Distinct {\n                        input: Arc::new(right),\n                    }),\n                )\n            });\n            Ok(optimized_plan)\n        }\n        LogicalPlan::Projection(projection) => {\n            // extract join filters\n            let (join_filters, subquery_input) = extract_join_filters(&projection.input)?;\n            // cannot optimize non-correlated subquery\n            if join_filters.is_empty() {\n                return Ok(None);\n            }\n            let input_schema = subquery_input.schema();\n            let project_exprs: Vec<Expr> =\n                collect_subquery_cols(&join_filters, input_schema.clone())?\n                    .into_iter()\n                    .map(Expr::Column)\n                    .collect();\n            let right = LogicalPlanBuilder::from(subquery_input)\n                
.project(project_exprs)?\n                .build()?;\n\n            // join_filters is not empty.\n            let join_filter = conjunction(join_filters).ok_or_else(|| {\n                DataFusionError::Internal(\"join filters should not be empty\".to_string())\n            })?;\n            Ok(Some((join_filter, right)))\n        }\n        _ => Ok(None),\n    }\n}\n\nstruct SubqueryInfo {\n    query: Subquery,\n    negated: bool,\n}\n\nimpl SubqueryInfo {\n    pub fn new(query: Subquery, negated: bool) -> Self {\n        Self { query, negated }\n    }\n}\n"
  },
  {
    "path": "src/sql/optimizer/decorrelate_where_in.rs",
    "content": "// Licensed to the Apache Software Foundation (ASF) under one\n// or more contributor license agreements.  See the NOTICE file\n// distributed with this work for additional information\n// regarding copyright ownership.  The ASF licenses this file\n// to you under the Apache License, Version 2.0 (the\n// \"License\"); you may not use this file except in compliance\n// with the License.  You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing,\n// software distributed under the License is distributed on an\n// \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n// KIND, either express or implied.  See the License for the\n// specific language governing permissions and limitations\n// under the License.\n\nuse std::sync::Arc;\n\nuse datafusion_python::{\n    datafusion_common::{alias::AliasGenerator, context, Column, DataFusionError, Result},\n    datafusion_expr::{\n        expr::InSubquery,\n        expr_rewriter::unnormalize_col,\n        logical_plan::{JoinType, Projection, Subquery},\n        Expr,\n        Filter,\n        LogicalPlan,\n        LogicalPlanBuilder,\n    },\n    datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule},\n};\nuse log::debug;\n\nuse crate::sql::optimizer::utils::{\n    collect_subquery_cols,\n    conjunction,\n    extract_join_filters,\n    only_or_err,\n    replace_qualified_name,\n    split_conjunction,\n};\n\n#[derive(Default)]\npub struct DecorrelateWhereIn {\n    alias: AliasGenerator,\n}\n\nimpl DecorrelateWhereIn {\n    #[allow(missing_docs)]\n    pub fn new() -> Self {\n        Self::default()\n    }\n\n    /// Finds expressions that have a where in subquery (and recurses when found)\n    ///\n    /// # Arguments\n    ///\n    /// * `predicate` - A conjunction to split and search\n    /// * `optimizer_config` - For generating unique subquery aliases\n    ///\n    /// Returns a 
tuple (subqueries, non-subquery expressions)\n    fn extract_subquery_exprs(\n        &self,\n        predicate: &Expr,\n        config: &dyn OptimizerConfig,\n    ) -> Result<(Vec<SubqueryInfo>, Vec<Expr>)> {\n        let filters = split_conjunction(predicate); // TODO: disjunctions\n\n        let mut subqueries = vec![];\n        let mut others = vec![];\n        for it in filters.iter() {\n            match it {\n                Expr::InSubquery(InSubquery {\n                    expr,\n                    subquery,\n                    negated,\n                }) => {\n                    let subquery_plan = self\n                        .try_optimize(&subquery.subquery, config)?\n                        .map(Arc::new)\n                        .unwrap_or_else(|| subquery.subquery.clone());\n                    let new_subquery = subquery.with_plan(subquery_plan);\n                    subqueries.push(SubqueryInfo::new(new_subquery, (**expr).clone(), *negated));\n                    // TODO: if subquery doesn't get optimized, optimized children are lost\n                }\n                _ => others.push((*it).clone()),\n            }\n        }\n\n        Ok((subqueries, others))\n    }\n}\n\nimpl OptimizerRule for DecorrelateWhereIn {\n    fn try_optimize(\n        &self,\n        plan: &LogicalPlan,\n        config: &dyn OptimizerConfig,\n    ) -> Result<Option<LogicalPlan>> {\n        match plan {\n            LogicalPlan::Filter(filter) => {\n                let (subqueries, other_exprs) =\n                    self.extract_subquery_exprs(&filter.predicate, config)?;\n                if subqueries.is_empty() {\n                    // regular filter, no subquery exists clause here\n                    return Ok(None);\n                }\n\n                // iterate through all exists clauses in predicate, turning each into a join\n                let mut cur_input = filter.input.as_ref().clone();\n                for subquery in subqueries {\n                
    cur_input = optimize_where_in(&subquery, &cur_input, &self.alias)?;\n                }\n\n                let expr = conjunction(other_exprs);\n                if let Some(expr) = expr {\n                    let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;\n                    cur_input = LogicalPlan::Filter(new_filter);\n                }\n\n                Ok(Some(cur_input))\n            }\n            _ => Ok(None),\n        }\n    }\n\n    fn name(&self) -> &str {\n        \"decorrelate_where_in\"\n    }\n\n    fn apply_order(&self) -> Option<ApplyOrder> {\n        Some(ApplyOrder::TopDown)\n    }\n}\n\n/// Optimize a WHERE IN subquery into a left-anti/left-semi join.\n/// If the subquery is correlated, we need to extract the join predicate from the subquery.\n///\n/// For example, given a query like:\n/// `select t1.a, t1.b from t1 where t1.a in (select t2.a from t2 where t1.b = t2.b and t1.c > t2.c)`\n///\n/// The optimized plan will be:\n///\n/// ```text\n/// Projection: t1.a, t1.b\n///   LeftSemi Join:  Filter: t1.a = __correlated_sq_1.a AND t1.b = __correlated_sq_1.b AND t1.c > __correlated_sq_1.c\n///     TableScan: t1\n///     SubqueryAlias: __correlated_sq_1\n///       Projection: t2.a AS a, t2.b, t2.c\n///         TableScan: t2\n/// ```\nfn optimize_where_in(\n    query_info: &SubqueryInfo,\n    left: &LogicalPlan,\n    alias: &AliasGenerator,\n) -> Result<LogicalPlan> {\n    let projection = try_from_plan(&query_info.query.subquery)\n        .map_err(|e| context!(\"a projection is required\", e))?;\n    let subquery_input = projection.input.clone();\n    // TODO add the validate logic to Analyzer\n    let subquery_expr = only_or_err(projection.expr.as_slice())\n        .map_err(|e| context!(\"single expression projection required\", e))?;\n\n    // extract join filters\n    let (join_filters, subquery_input) = extract_join_filters(subquery_input.as_ref())?;\n\n    // in_predicate may also be included in the join filters; if so, remove it 
from the join filters.\n    let in_predicate = Expr::eq(query_info.where_in_expr.clone(), subquery_expr.clone());\n    let join_filters = remove_duplicated_filter(join_filters, in_predicate);\n\n    // replace qualified name with subquery alias.\n    let subquery_alias = alias.next(\"__correlated_sq\");\n    let input_schema = subquery_input.schema();\n    let mut subquery_cols = collect_subquery_cols(&join_filters, input_schema.clone())?;\n    let join_filter = conjunction(join_filters).map_or(Ok(None), |filter| {\n        replace_qualified_name(filter, &subquery_cols, &subquery_alias).map(Option::Some)\n    })?;\n\n    // add projection\n    if let Expr::Column(col) = subquery_expr {\n        subquery_cols.remove(col);\n    }\n    let subquery_expr_name = format!(\"{:?}\", unnormalize_col(subquery_expr.clone()));\n    let first_expr = subquery_expr.clone().alias(subquery_expr_name.clone());\n    let projection_exprs: Vec<Expr> = [first_expr]\n        .into_iter()\n        .chain(subquery_cols.into_iter().map(Expr::Column))\n        .collect();\n\n    let right = LogicalPlanBuilder::from(subquery_input)\n        .project(projection_exprs)?\n        .alias(subquery_alias.clone())?\n        .build()?;\n\n    // join our sub query into the main plan\n    let join_type = match query_info.negated {\n        true => JoinType::LeftAnti,\n        false => JoinType::LeftSemi,\n    };\n    let right_join_col = Column::new(Some(subquery_alias), subquery_expr_name);\n    let in_predicate = Expr::eq(\n        query_info.where_in_expr.clone(),\n        Expr::Column(right_join_col),\n    );\n    let join_filter = join_filter\n        .map(|filter| in_predicate.clone().and(filter))\n        .unwrap_or_else(|| in_predicate);\n\n    let new_plan = LogicalPlanBuilder::from(left.clone())\n        .join(\n            right,\n            join_type,\n            (Vec::<Column>::new(), Vec::<Column>::new()),\n            Some(join_filter),\n        )?\n        .build()?;\n\n    
debug!(\"where in optimized:\\n{}\", new_plan.display_indent());\n    Ok(new_plan)\n}\n\nfn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: Expr) -> Vec<Expr> {\n    filters\n        .into_iter()\n        .filter(|filter| {\n            if filter == &in_predicate {\n                return false;\n            }\n\n            // ignore the operand order, but only when the operators match\n            // (`&&` binds tighter than `||`, so the alternatives must be grouped)\n            !match (filter, &in_predicate) {\n                (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => {\n                    (a_expr.op == b_expr.op)\n                        && ((a_expr.left == b_expr.left && a_expr.right == b_expr.right)\n                            || (a_expr.left == b_expr.right && a_expr.right == b_expr.left))\n                }\n                _ => false,\n            }\n        })\n        .collect::<Vec<_>>()\n}\n\nfn try_from_plan(plan: &LogicalPlan) -> Result<&Projection> {\n    match plan {\n        LogicalPlan::Projection(it) => Ok(it),\n        _ => Err(DataFusionError::Internal(\n            \"Could not coerce into Projection!\".to_string(),\n        )),\n    }\n}\n\nstruct SubqueryInfo {\n    query: Subquery,\n    where_in_expr: Expr,\n    negated: bool,\n}\n\nimpl SubqueryInfo {\n    pub fn new(query: Subquery, expr: Expr, negated: bool) -> Self {\n        Self {\n            query,\n            where_in_expr: expr,\n            negated,\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/optimizer/dynamic_partition_pruning.rs",
    "content": "//! Optimizer rule for dynamic partition pruning (DPP)\n//!\n//! DPP refers to a query optimization rule in which distinct values in an inner join are used as\n//! filters in a table scan. This allows us to eliminate all other rows which do not fit the join\n//! condition from being read at all.\n//!\n//! Furthermore, a table involved in a join may be filtered during a scan, which allows us to\n//! further prune the values to be read.\n\nuse std::{\n    collections::{HashMap, HashSet},\n    fs,\n    hash::{Hash, Hasher},\n};\n\nuse datafusion_python::{\n    datafusion::parquet::{\n        basic::Type as BasicType,\n        file::reader::{FileReader, SerializedFileReader},\n        record::{reader::RowIter, RowAccessor},\n        schema::{parser::parse_message_type, types::Type},\n    },\n    datafusion_common::{Column, Result, ScalarValue},\n    datafusion_expr::{\n        expr::InList,\n        logical_plan::LogicalPlan,\n        Expr,\n        JoinType,\n        Operator,\n        TableScan,\n    },\n    datafusion_optimizer::{OptimizerConfig, OptimizerRule},\n};\nuse log::warn;\n\nuse crate::sql::table::DaskTableSource;\n\n// Optimizer rule for dynamic partition pruning\npub struct DynamicPartitionPruning {\n    /// Ratio of the size of the dimension tables to fact tables\n    fact_dimension_ratio: f64,\n}\n\nimpl DynamicPartitionPruning {\n    pub fn new(fact_dimension_ratio: f64) -> Self {\n        Self {\n            fact_dimension_ratio,\n        }\n    }\n}\n\nimpl OptimizerRule for DynamicPartitionPruning {\n    fn name(&self) -> &str {\n        \"dynamic_partition_pruning\"\n    }\n\n    fn try_optimize(\n        &self,\n        plan: &LogicalPlan,\n        _config: &dyn OptimizerConfig,\n    ) -> Result<Option<LogicalPlan>> {\n        // Parse the LogicalPlan and store tables and columns being (inner) joined upon. 
We do this\n        // by creating a HashSet of all InnerJoins' join.on and join.filters\n        let join_conds = gather_joins(plan);\n        let tables = gather_tables(plan);\n        let aliases = gather_aliases(plan);\n\n        if join_conds.is_empty() || tables.is_empty() {\n            // No InnerJoins to optimize with\n            Ok(None)\n        } else {\n            // Find the size of the largest table in the query\n            let mut largest_size = 1_f64;\n            for table in &tables {\n                let table_size = table.1.size.unwrap_or(0) as f64;\n                if table_size > largest_size {\n                    largest_size = table_size;\n                }\n            }\n\n            let mut join_values = vec![];\n            let mut join_tables = vec![];\n            let mut join_fields = vec![];\n            let mut fact_tables = HashSet::new();\n\n            // Iterate through all inner joins in the query\n            for join_cond in &join_conds {\n                let join_on = &join_cond.on;\n                for on_i in join_on {\n                    // Obtain tables and columns (fields) involved in join\n                    let (left_on, right_on) = (&on_i.0, &on_i.1);\n                    let (mut left_table, mut right_table) = (None, None);\n                    let (mut left_field, mut right_field) = (None, None);\n\n                    if let Expr::Column(c) = left_on {\n                        left_table = Some(c.relation.clone().unwrap().to_string().clone());\n                        left_field = Some(c.name.clone());\n                    }\n                    if let Expr::Column(c) = right_on {\n                        right_table = Some(c.relation.clone().unwrap().to_string().clone());\n                        right_field = Some(c.name.clone());\n                    }\n\n                    // For now, if it is not a join between columns then we skip the rule\n                    // TODO: 
https://github.com/dask-contrib/dask-sql/issues/1121\n                    if left_table.is_none() || right_table.is_none() {\n                        continue;\n                    }\n\n                    let (mut left_table, mut right_table) =\n                        (left_table.unwrap(), right_table.unwrap());\n                    let (left_field, right_field) = (left_field.unwrap(), right_field.unwrap());\n\n                    let (mut left_filtered_table, mut right_filtered_table) = (None, None);\n\n                    // Check if join uses an alias instead of the table name itself. Need to use\n                    // the actual table name to obtain its filepath\n                    let left_alias = aliases.get(&left_table.clone());\n                    if let Some(t) = left_alias {\n                        left_table = t.to_string()\n                    }\n                    let right_alias = aliases.get(&right_table.clone());\n                    if let Some(t) = right_alias {\n                        right_table = t.to_string()\n                    }\n\n                    // A more complicated alias, e.g. an alias for a nested select, means it's not\n                    // obvious which file(s) should be read\n                    if !tables.contains_key(&left_table) || !tables.contains_key(&right_table) {\n                        continue;\n                    }\n\n                    // Determine whether a table is a fact or dimension table. 
If it's a dimension\n                    // table, we should read it in and use the rule\n                    if tables\n                        .get(&left_table.clone())\n                        .unwrap()\n                        .size\n                        .unwrap_or(largest_size as usize) as f64\n                        / largest_size\n                        < self.fact_dimension_ratio\n                    {\n                        left_filtered_table =\n                            read_table(left_table.clone(), left_field.clone(), tables.clone());\n                    } else {\n                        fact_tables.insert(left_table.clone());\n                    }\n                    if tables\n                        .get(&right_table.clone())\n                        .unwrap()\n                        .size\n                        .unwrap_or(largest_size as usize) as f64\n                        / largest_size\n                        < self.fact_dimension_ratio\n                    {\n                        right_filtered_table =\n                            read_table(right_table.clone(), right_field.clone(), tables.clone());\n                    } else {\n                        fact_tables.insert(right_table.clone());\n                    }\n\n                    join_values.push((left_filtered_table, right_filtered_table));\n                    join_tables.push((left_table, right_table));\n                    join_fields.push((left_field, right_field));\n                }\n            }\n            // Creates HashMap of all tables and field with their unique values to be set in the\n            // TableScan\n            let filter_values = combine_sets(join_values, join_tables, join_fields, fact_tables);\n            // Optimize and return the plan\n            optimize_table_scans(plan, filter_values)\n        }\n    }\n}\n\n/// Represents relevant information in an InnerJoin\n#[derive(Clone, Debug, Eq, Hash, PartialEq)]\nstruct JoinInfo {\n    
/// Equijoin clause expressed as pairs of (left, right) join expressions\n    on: Vec<(Expr, Expr)>,\n    /// Filters applied during join (non-equi conditions)\n    /// TODO: https://github.com/dask-contrib/dask-sql/issues/1121\n    filter: Option<Expr>,\n}\n\n// This function parses through the LogicalPlan, grabs relevant information from an InnerJoin, and\n// adds them to a HashSet\nfn gather_joins(plan: &LogicalPlan) -> HashSet<JoinInfo> {\n    let mut current_plan = plan.clone();\n    let mut join_info = HashSet::new();\n    loop {\n        if current_plan.inputs().is_empty() {\n            break;\n        } else if current_plan.inputs().len() > 1 {\n            match current_plan {\n                LogicalPlan::Join(ref j) => {\n                    if j.join_type == JoinType::Inner {\n                        // Store tables and columns that are being (inner) joined upon\n                        let info = JoinInfo {\n                            on: j.on.clone(),\n                            filter: j.filter.clone(),\n                        };\n                        join_info.insert(info);\n\n                        // Recurse on left and right inputs of Join\n                        let (left_joins, right_joins) =\n                            (gather_joins(&j.left), gather_joins(&j.right));\n\n                        // Add left_joins and right_joins to HashSet\n                        join_info.extend(left_joins);\n                        join_info.extend(right_joins);\n                    } else {\n                        // We don't run the rule if there are non-inner joins in the query\n                        return HashSet::new();\n                    }\n                }\n                LogicalPlan::CrossJoin(ref c) => {\n                    // Recurse on left and right inputs of CrossJoin\n                    let (left_joins, right_joins) = (gather_joins(&c.left), gather_joins(&c.right));\n\n                    // Add left_joins and right_joins to 
HashSet\n                    join_info.extend(left_joins);\n                    join_info.extend(right_joins);\n                }\n                LogicalPlan::Union(ref u) => {\n                    // Recurse on inputs vector of Union\n                    for input in &u.inputs {\n                        let joins = gather_joins(input);\n\n                        // Add joins to HashSet\n                        join_info.extend(joins);\n                    }\n                }\n                _ => {\n                    warn!(\"Skipping optimizer rule 'DynamicPartitionPruning'\");\n                    return HashSet::new();\n                }\n            }\n            break;\n        } else {\n            // Move on to next step\n            current_plan = current_plan.inputs()[0].clone();\n        }\n    }\n    join_info\n}\n\n/// Represents relevant information in a TableScan\n#[derive(Clone, Debug, Eq, Hash, PartialEq)]\nstruct TableInfo {\n    /// The name of the table\n    table_name: String,\n    /// The path and filename of the table\n    filepath: String,\n    /// The number of rows in the table\n    size: Option<usize>,\n    /// Optional expressions to be used as filters by the table provider\n    filters: Vec<Expr>,\n}\n\n// This function parses through the LogicalPlan, grabs relevant information from a TableScan, and\n// adds them to a HashMap where the key is the table name\nfn gather_tables(plan: &LogicalPlan) -> HashMap<String, TableInfo> {\n    let mut current_plan = plan.clone();\n    let mut tables = HashMap::new();\n    loop {\n        if current_plan.inputs().is_empty() {\n            if let LogicalPlan::TableScan(ref t) = current_plan {\n                // Use TableScan to get the filepath and/or size\n                let filepath = get_filepath(&current_plan);\n                let size = get_table_size(&current_plan);\n                match filepath {\n                    Some(f) => {\n                        // TODO: Add better handling 
for when a table is read in more than once\n                        // https://github.com/dask-contrib/dask-sql/issues/1121\n                        if tables.contains_key(&t.table_name.to_string()) {\n                            return HashMap::new();\n                        }\n\n                        tables.insert(\n                            t.table_name.to_string(),\n                            TableInfo {\n                                table_name: t.table_name.to_string(),\n                                filepath: f.clone(),\n                                size,\n                                filters: t.filters.clone(),\n                            },\n                        );\n                        break;\n                    }\n                    None => return HashMap::new(),\n                }\n            }\n            break;\n        } else if current_plan.inputs().len() > 1 {\n            match current_plan {\n                LogicalPlan::Join(ref j) => {\n                    // Recurse on left and right inputs of Join\n                    let (left_tables, right_tables) =\n                        (gather_tables(&j.left), gather_tables(&j.right));\n\n                    if check_table_overlaps(&tables, &left_tables, &right_tables) {\n                        return HashMap::new();\n                    }\n\n                    // Add left_tables and right_tables to HashMap\n                    tables.extend(left_tables);\n                    tables.extend(right_tables);\n                }\n                LogicalPlan::CrossJoin(ref c) => {\n                    // Recurse on left and right inputs of CrossJoin\n                    let (left_tables, right_tables) =\n                        (gather_tables(&c.left), gather_tables(&c.right));\n\n                    if check_table_overlaps(&tables, &left_tables, &right_tables) {\n                        return HashMap::new();\n                    }\n\n                    // Add left_tables and 
right_tables to HashMap\n                    tables.extend(left_tables);\n                    tables.extend(right_tables);\n                }\n                LogicalPlan::Union(ref u) => {\n                    // Recurse on inputs vector of Union\n                    for input in &u.inputs {\n                        let union_tables = gather_tables(input);\n\n                        // TODO: Add better handling for when a table is read in more than once\n                        // https://github.com/dask-contrib/dask-sql/issues/1121\n                        if tables.keys().any(|k| union_tables.contains_key(k))\n                            || union_tables.keys().any(|k| tables.contains_key(k))\n                        {\n                            return HashMap::new();\n                        }\n\n                        // Add union_tables to HashMap\n                        tables.extend(union_tables);\n                    }\n                }\n                _ => {\n                    warn!(\"Skipping optimizer rule 'DynamicPartitionPruning'\");\n                    return HashMap::new();\n                }\n            }\n            break;\n        } else {\n            // Move on to next step\n            current_plan = current_plan.inputs()[0].clone();\n        }\n    }\n    tables\n}\n\n// TODO: Add better handling for when a table is read in more than once\n// https://github.com/dask-contrib/dask-sql/issues/1121\nfn check_table_overlaps(\n    m1: &HashMap<String, TableInfo>,\n    m2: &HashMap<String, TableInfo>,\n    m3: &HashMap<String, TableInfo>,\n) -> bool {\n    m1.keys().any(|k| m2.contains_key(k))\n        || m2.keys().any(|k| m1.contains_key(k))\n        || m1.keys().any(|k| m3.contains_key(k))\n        || m3.keys().any(|k| m1.contains_key(k))\n        || m2.keys().any(|k| m3.contains_key(k))\n        || m3.keys().any(|k| m2.contains_key(k))\n}\n\nfn get_filepath(plan: &LogicalPlan) -> Option<&String> {\n    match plan {\n        
LogicalPlan::TableScan(scan) => scan\n            .source\n            .as_any()\n            .downcast_ref::<DaskTableSource>()?\n            .filepath(),\n        _ => None,\n    }\n}\n\nfn get_table_size(plan: &LogicalPlan) -> Option<usize> {\n    match plan {\n        LogicalPlan::TableScan(scan) => scan\n            .source\n            .as_any()\n            .downcast_ref::<DaskTableSource>()?\n            .statistics()\n            .map(|stats| stats.get_row_count() as usize),\n        _ => None,\n    }\n}\n\n// This function parses through the LogicalPlan, grabs any aliases, and adds them to a HashMap\n// where the key is the alias name and the value is the table name\nfn gather_aliases(plan: &LogicalPlan) -> HashMap<String, String> {\n    let mut current_plan = plan.clone();\n    let mut aliases = HashMap::new();\n    loop {\n        if current_plan.inputs().is_empty() {\n            break;\n        } else if current_plan.inputs().len() > 1 {\n            match current_plan {\n                LogicalPlan::Join(ref j) => {\n                    // Recurse on left and right inputs of Join\n                    let (left_aliases, right_aliases) =\n                        (gather_aliases(&j.left), gather_aliases(&j.right));\n\n                    // Add left_aliases and right_aliases to HashMap\n                    aliases.extend(left_aliases);\n                    aliases.extend(right_aliases);\n                }\n                LogicalPlan::CrossJoin(ref c) => {\n                    // Recurse on left and right inputs of CrossJoin\n                    let (left_aliases, right_aliases) =\n                        (gather_aliases(&c.left), gather_aliases(&c.right));\n\n                    // Add left_aliases and right_aliases to HashMap\n                    aliases.extend(left_aliases);\n                    aliases.extend(right_aliases);\n                }\n                LogicalPlan::Union(ref u) => {\n                    // Recurse on inputs vector of Union\n 
                   for input in &u.inputs {\n                        let union_aliases = gather_aliases(input);\n\n                        // Add union_aliases to HashMap\n                        aliases.extend(union_aliases);\n                    }\n                }\n                _ => {\n                    return HashMap::new();\n                }\n            }\n            break;\n        } else {\n            if let LogicalPlan::SubqueryAlias(ref s) = current_plan {\n                match *s.input {\n                    LogicalPlan::TableScan(ref t) => {\n                        aliases.insert(s.alias.to_string(), t.table_name.to_string());\n                    }\n                    // Sometimes a TableScan is immediately followed by a Projection, so we can\n                    // still use the alias for the table\n                    LogicalPlan::Projection(ref p) => {\n                        if let LogicalPlan::TableScan(ref t) = *p.input {\n                            aliases.insert(s.alias.to_string(), t.table_name.to_string());\n                        }\n                    }\n                    _ => (),\n                }\n            }\n            // Move on to next step\n            current_plan = current_plan.inputs()[0].clone();\n        }\n    }\n    aliases\n}\n\n// Wrapper for floats, since they are not hashable\n#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]\nstruct FloatWrapper(f64);\n\nimpl Eq for FloatWrapper {}\n\nimpl Hash for FloatWrapper {\n    fn hash<H: Hasher>(&self, state: &mut H) {\n        // Convert the f64 to its raw u64 bit pattern using to_bits\n        let bits: u64 = self.0.to_bits();\n        // Use the u64's hash implementation\n        bits.hash(state);\n    }\n}\n\n// Wrapper for possible row value types\n#[derive(Clone, Debug, Eq, Hash, PartialEq)]\nenum RowValue {\n    String(Option<String>),\n    Int64(Option<i64>),\n    Int32(Option<i32>),\n    Double(Option<FloatWrapper>),\n}\n\n// This function uses the 
table name, column name, and filters to read in the relevant columns,\n// filter out row values, and construct a HashSet of relevant row values for the specified column,\n// i.e., the column involved in the join\nfn read_table(\n    table_string: String,\n    field_string: String,\n    tables: HashMap<String, TableInfo>,\n) -> Option<HashSet<RowValue>> {\n    let file_path = tables.get(&table_string).unwrap().filepath.clone();\n    let paths: fs::ReadDir;\n    let mut files = vec![];\n    if fs::metadata(&file_path)\n        .map(|metadata| metadata.is_dir())\n        .unwrap_or(false)\n    {\n        // Obtain filepaths to all relevant Parquet files, e.g., in a directory of Parquet files\n        paths = fs::read_dir(&file_path).unwrap();\n        for path in paths {\n            files.push(path.unwrap().path().display().to_string())\n        }\n    } else {\n        // Obtain single Parquet file\n        files.push(file_path);\n    }\n\n    // Using the filepaths to the Parquet tables, obtain the schemas of the relevant tables\n    let schema: &Type = &SerializedFileReader::try_from(files[0].clone())\n        .unwrap()\n        .metadata()\n        .file_metadata()\n        .schema()\n        .clone();\n\n    // Use the schemas of the relevant tables to obtain the physical type of the relevant columns\n    let physical_type = get_physical_type(schema, field_string.clone());\n\n    // A TableScan may include existing filters. These conditions should be used to filter the data\n    // after being read. 
Therefore, the columns involved in these filters should be read in as well\n    let filters = tables.get(&table_string).unwrap().filters.clone();\n    let filtered_fields = get_filtered_fields(&filters, schema, field_string.clone());\n    let filtered_string = filtered_fields.0;\n    let filtered_types = filtered_fields.1;\n    let filtered_names = filtered_fields.2;\n\n    if filters.len() != filtered_names.len() {\n        warn!(\"Unable to check existing filters for optimizer rule 'DynamicPartitionPruning'\");\n        return None;\n    }\n\n    // Specify which columns to include in the reader, then read in the rows\n    let repetition = get_repetition(schema, field_string.clone());\n    let physical_type = physical_type.unwrap().to_string();\n    let projection_schema = \"message schema { \".to_owned()\n        + &filtered_string\n        + &repetition.unwrap()\n        + \" \"\n        + &physical_type\n        + \" \"\n        + &field_string\n        + \"; }\";\n    let projection = parse_message_type(&projection_schema).ok();\n\n    let mut rows = Vec::new();\n    for file in files {\n        let reader_result = SerializedFileReader::try_from(&*file.clone());\n        if let Ok(reader) = reader_result {\n            let row_iter_result = RowIter::from_file_into(Box::new(reader))\n                .project(projection.clone())\n                .ok();\n            if let Some(row_iter) = row_iter_result {\n                rows.extend(row_iter.map(|r| r.expect(\"Parquet error encountered\")));\n            } else {\n                // TODO: Investigate cases when this would happen\n                rows.clear();\n                break;\n            }\n        } else {\n            rows.clear();\n            break;\n        }\n    }\n    if rows.is_empty() {\n        return None;\n    }\n\n    // Create HashSets for the join column values\n    let mut value_set: HashSet<RowValue> = HashSet::new();\n    for row in rows {\n        // Since a TableScan may have its 
own filters, we want to ensure that the values in\n        // value_set satisfy the TableScan filters\n        let mut satisfies_filters = true;\n        let mut row_index = 0;\n        for index in 0..filters.len() {\n            if filtered_names[index] != field_string {\n                let current_type = &filtered_types[index];\n                match current_type.as_str() {\n                    \"BYTE_ARRAY\" => {\n                        let string_value = row.get_string(row_index).ok();\n                        if !satisfies_string(string_value, filters[index].clone()) {\n                            satisfies_filters = false;\n                        }\n                    }\n                    \"INT64\" => {\n                        let long_value = row.get_long(row_index).ok();\n                        if !satisfies_int64(long_value, filters[index].clone()) {\n                            satisfies_filters = false;\n                        }\n                    }\n                    \"INT32\" => {\n                        let int_value = row.get_int(row_index).ok();\n                        if !satisfies_int32(int_value, filters[index].clone()) {\n                            satisfies_filters = false;\n                        }\n                    }\n                    \"DOUBLE\" => {\n                        let double_value = row.get_double(row_index).ok();\n                        if !satisfies_float(double_value, filters[index].clone()) {\n                            satisfies_filters = false;\n                        }\n                    }\n                    u => panic!(\"Unknown PhysicalType {u}\"),\n                }\n                row_index += 1;\n            }\n        }\n        // After verifying that the row satisfies all existing filters, we add the column value to\n        // the HashSet\n        if satisfies_filters {\n            match physical_type.as_str() {\n                \"BYTE_ARRAY\" => {\n                    let r = 
row.get_string(row_index).ok();\n                    value_set.insert(RowValue::String(r.cloned()));\n                }\n                \"INT64\" => {\n                    let r = row.get_long(row_index).ok();\n                    value_set.insert(RowValue::Int64(r));\n                }\n                \"INT32\" => {\n                    let r = row.get_int(row_index).ok();\n                    value_set.insert(RowValue::Int32(r));\n                }\n                \"DOUBLE\" => {\n                    let r = row.get_double(row_index).ok();\n                    if let Some(f) = r {\n                        value_set.insert(RowValue::Double(Some(FloatWrapper(f))));\n                    } else {\n                        value_set.insert(RowValue::Double(None));\n                    }\n                }\n                _ => panic!(\"Unknown PhysicalType\"),\n            }\n        }\n    }\n\n    Some(value_set)\n}\n\n// A column has a physical_type (INT64, etc.) that needs to be included when specifying which\n// columns to read in. 
To get the physical_type, we grab it from the schema\nfn get_physical_type(schema: &Type, field: String) -> Option<BasicType> {\n    match schema {\n        Type::GroupType {\n            basic_info: _,\n            fields,\n        } => {\n            for f in fields {\n                let match_field = &*f.clone();\n                match match_field {\n                    Type::PrimitiveType {\n                        basic_info,\n                        physical_type,\n                        ..\n                    } => {\n                        if basic_info.name() == field {\n                            return Some(*physical_type);\n                        }\n                    }\n                    _ => return None,\n                }\n            }\n            None\n        }\n        _ => None,\n    }\n}\n\n// A column has a repetition (i.e., REQUIRED or OPTIONAL) that needs to be included when specifying\n// which columns to read in. To get the repetition, we grab it from the schema\nfn get_repetition(schema: &Type, field: String) -> Option<String> {\n    match schema {\n        Type::GroupType {\n            basic_info: _,\n            fields,\n        } => {\n            for f in fields {\n                let match_field = &*f.clone();\n                match match_field {\n                    Type::PrimitiveType { basic_info, .. } => {\n                        if basic_info.name() == field {\n                            return Some(basic_info.repetition().to_string());\n                        }\n                    }\n                    _ => return None,\n                }\n            }\n            None\n        }\n        _ => None,\n    }\n}\n\n// This is a helper function to deal with TableScan filters for reading in the data. The first\n// value returned is a string representation of the projection used to read in the relevant\n// columns. 
The second value returned is a vector of the physical_type of each column that has\n// a filter, in the order that they are being read. The third value returned is a vector of the\n// column names, in the order that they are being read.\nfn get_filtered_fields(\n    filters: &Vec<Expr>,\n    schema: &Type,\n    field: String,\n) -> (String, Vec<String>, Vec<String>) {\n    // Used to create a string representation of the projection\n    // for the TableScan filters to be read\n    let mut filtered_fields = vec![];\n    // All physical types involved in TableScan filters\n    let mut filtered_types = vec![];\n    // All columns involved in TableScan filters\n    let mut filtered_columns = vec![];\n    for filter in filters {\n        match filter {\n            Expr::BinaryExpr(b) => {\n                if let Expr::Column(column) = &*b.left {\n                    push_filtered_fields(\n                        column,\n                        schema,\n                        field.clone(),\n                        &mut filtered_fields,\n                        &mut filtered_columns,\n                        &mut filtered_types,\n                    );\n                }\n            }\n            Expr::IsNotNull(e) => {\n                if let Expr::Column(column) = &**e {\n                    push_filtered_fields(\n                        column,\n                        schema,\n                        field.clone(),\n                        &mut filtered_fields,\n                        &mut filtered_columns,\n                        &mut filtered_types,\n                    );\n                }\n            }\n            _ => (),\n        }\n    }\n    (filtered_fields.join(\"\"), filtered_types, filtered_columns)\n}\n\n// Helper function for get_filtered_fields\nfn push_filtered_fields(\n    column: &Column,\n    schema: &Type,\n    field: String,\n    filtered_fields: &mut Vec<String>,\n    filtered_columns: &mut Vec<String>,\n    filtered_types: &mut 
Vec<String>,\n) {\n    let current_field = column.name.clone();\n    let physical_type = get_physical_type(schema, current_field.clone())\n        .unwrap()\n        .to_string();\n    if current_field != field {\n        let repetition = get_repetition(schema, current_field.clone());\n        filtered_fields.push(repetition.unwrap());\n        filtered_fields.push(\" \".to_string());\n\n        filtered_fields.push(physical_type.clone());\n        filtered_fields.push(\" \".to_string());\n\n        filtered_fields.push(current_field.clone());\n        filtered_fields.push(\"; \".to_string());\n    }\n    filtered_types.push(physical_type);\n    filtered_columns.push(current_field);\n}\n\n// Returns a boolean representing whether a string satisfies a given filter\nfn satisfies_string(string_value: Option<&String>, filter: Expr) -> bool {\n    match filter {\n        Expr::BinaryExpr(b) => match b.op {\n            Operator::Eq => Expr::Literal(ScalarValue::Utf8(string_value.cloned())) == *b.right,\n            Operator::NotEq => Expr::Literal(ScalarValue::Utf8(string_value.cloned())) != *b.right,\n            _ => {\n                panic!(\"Unknown satisfies_string operator\");\n            }\n        },\n        Expr::IsNotNull(_) => string_value.is_some(),\n        _ => {\n            panic!(\"Unknown satisfies_string Expr\");\n        }\n    }\n}\n\n// Returns a boolean representing whether an Int64 satisfies a given filter\nfn satisfies_int64(long_value: Option<i64>, filter: Expr) -> bool {\n    match filter {\n        Expr::BinaryExpr(b) => {\n            let filter_value = *b.right;\n            let int_value: i64 = match filter_value {\n                Expr::Literal(ScalarValue::Int64(i)) => i.unwrap(),\n                Expr::Literal(ScalarValue::Int32(i)) => i64::from(i.unwrap()),\n                Expr::Literal(ScalarValue::Float64(i)) => i.unwrap() as i64,\n                Expr::Literal(ScalarValue::TimestampNanosecond(i, None)) => i.unwrap(),\n           
     Expr::Literal(ScalarValue::Date32(i)) => i64::from(i.unwrap()),\n                // TODO: Add logic to check if the string can be converted to a timestamp\n                Expr::Literal(ScalarValue::Utf8(_)) => return false,\n                _ => {\n                    panic!(\"Unknown ScalarValue type {filter_value}\");\n                }\n            };\n            let filter_value = Expr::Literal(ScalarValue::Int64(Some(int_value)));\n            match b.op {\n                Operator::Eq => Expr::Literal(ScalarValue::Int64(long_value)) == filter_value,\n                Operator::NotEq => Expr::Literal(ScalarValue::Int64(long_value)) != filter_value,\n                Operator::Gt => Expr::Literal(ScalarValue::Int64(long_value)) > filter_value,\n                Operator::Lt => Expr::Literal(ScalarValue::Int64(long_value)) < filter_value,\n                Operator::GtEq => Expr::Literal(ScalarValue::Int64(long_value)) >= filter_value,\n                Operator::LtEq => Expr::Literal(ScalarValue::Int64(long_value)) <= filter_value,\n                _ => {\n                    panic!(\"Unknown satisfies_int64 operator\");\n                }\n            }\n        }\n        Expr::IsNotNull(_) => long_value.is_some(),\n        _ => {\n            panic!(\"Unknown satisfies_int64 Expr\");\n        }\n    }\n}\n\n// Returns a boolean representing whether an Int32 satisfies a given filter\nfn satisfies_int32(long_value: Option<i32>, filter: Expr) -> bool {\n    match filter {\n        Expr::BinaryExpr(b) => {\n            let filter_value = *b.right;\n            let int_value: i32 = match filter_value {\n                Expr::Literal(ScalarValue::Int64(i)) => i.unwrap() as i32,\n                Expr::Literal(ScalarValue::Int32(i)) => i.unwrap(),\n                Expr::Literal(ScalarValue::Float64(i)) => i.unwrap() as i32,\n                _ => {\n                    panic!(\"Unknown ScalarValue type {filter_value}\");\n                }\n            };\n         
   let filter_value = Expr::Literal(ScalarValue::Int32(Some(int_value)));\n            match b.op {\n                Operator::Eq => Expr::Literal(ScalarValue::Int32(long_value)) == filter_value,\n                Operator::NotEq => Expr::Literal(ScalarValue::Int32(long_value)) != filter_value,\n                Operator::Gt => Expr::Literal(ScalarValue::Int32(long_value)) > filter_value,\n                Operator::Lt => Expr::Literal(ScalarValue::Int32(long_value)) < filter_value,\n                Operator::GtEq => Expr::Literal(ScalarValue::Int32(long_value)) >= filter_value,\n                Operator::LtEq => Expr::Literal(ScalarValue::Int32(long_value)) <= filter_value,\n                _ => {\n                    panic!(\"Unknown satisfies_int32 operator\");\n                }\n            }\n        }\n        Expr::IsNotNull(_) => long_value.is_some(),\n        _ => {\n            panic!(\"Unknown satisfies_int32 Expr\");\n        }\n    }\n}\n\n// Returns a boolean representing whether a Float64 satisfies a given filter\nfn satisfies_float(long_value: Option<f64>, filter: Expr) -> bool {\n    match filter {\n        Expr::BinaryExpr(b) => {\n            let filter_value = *b.right;\n            let float_value: f64 = match filter_value {\n                Expr::Literal(ScalarValue::Int64(i)) => i.unwrap() as f64,\n                Expr::Literal(ScalarValue::Int32(i)) => i.unwrap() as f64,\n                Expr::Literal(ScalarValue::Float64(i)) => i.unwrap(),\n                _ => {\n                    panic!(\"Unknown ScalarValue type {filter_value}\");\n                }\n            };\n            let filter_value = Expr::Literal(ScalarValue::Float64(Some(float_value)));\n            match b.op {\n                Operator::Eq => Expr::Literal(ScalarValue::Float64(long_value)) == filter_value,\n                Operator::NotEq => Expr::Literal(ScalarValue::Float64(long_value)) != filter_value,\n                Operator::Gt => 
Expr::Literal(ScalarValue::Float64(long_value)) > filter_value,\n                Operator::Lt => Expr::Literal(ScalarValue::Float64(long_value)) < filter_value,\n                Operator::GtEq => Expr::Literal(ScalarValue::Float64(long_value)) >= filter_value,\n                Operator::LtEq => Expr::Literal(ScalarValue::Float64(long_value)) <= filter_value,\n                _ => {\n                    panic!(\"Unknown satisfies_float operator\");\n                }\n            }\n        }\n        Expr::IsNotNull(_) => long_value.is_some(),\n        _ => {\n            panic!(\"Unknown satisfies_float Expr\");\n        }\n    }\n}\n\n// Used to simplify the signature of combine_sets\ntype RowHashSet = HashSet<RowValue>;\ntype RowOptionHashSet = Option<RowHashSet>;\ntype RowTuple = (RowOptionHashSet, RowOptionHashSet);\ntype RowVec = Vec<RowTuple>;\n\n// Given a vector of hashsets to be set as TableScan filters, a vector of tuples representing the\n// tables involved in a join, a vector of tuples representing the columns involved in a join, and\n// a hashset of fact tables in the query; return a hashmap where the key is a tuple of the table\n// and column names, and the value is the hashset representing the INLIST filter specified in the\n// TableScan.\nfn combine_sets(\n    join_values: RowVec,\n    join_tables: Vec<(String, String)>,\n    join_fields: Vec<(String, String)>,\n    fact_tables: HashSet<String>,\n) -> HashMap<(String, String), HashSet<RowValue>> {\n    let mut sets: HashMap<(String, String), HashSet<RowValue>> = HashMap::new();\n    for i in 0..join_values.len() {\n        // Case when we were able to read in both tables involved in the join\n        if let (Some(set1), Some(set2)) = (&join_values[i].0, &join_values[i].1) {\n            // The INLIST vector will be the intersection of both hashsets\n            let set_intersection = set1.intersection(set2);\n            let mut values = HashSet::new();\n            for value in set_intersection 
{\n                values.insert(value.clone());\n            }\n\n            let current_table = join_tables[i].0.clone();\n            // We only create INLIST filters for fact tables\n            if fact_tables.contains(&current_table) {\n                let current_field = join_fields[i].0.clone();\n                add_to_existing_set(&mut sets, values.clone(), current_table, current_field);\n            }\n\n            let current_table = join_tables[i].1.clone();\n            // We only create INLIST filters for fact tables\n            if fact_tables.contains(&current_table) {\n                let current_field = join_fields[i].1.clone();\n                add_to_existing_set(&mut sets, values.clone(), current_table, current_field);\n            }\n        // Case when we were only able to read in the left table of the join\n        } else if let Some(values) = &join_values[i].0 {\n            let current_table = join_tables[i].0.clone();\n            // We only create INLIST filters for fact tables\n            if fact_tables.contains(&current_table) {\n                let current_field = join_fields[i].0.clone();\n                add_to_existing_set(&mut sets, values.clone(), current_table, current_field);\n            }\n\n            let current_table = join_tables[i].1.clone();\n            // We only create INLIST filters for fact tables\n            if fact_tables.contains(&current_table) {\n                let current_field = join_fields[i].1.clone();\n                add_to_existing_set(&mut sets, values.clone(), current_table, current_field);\n            }\n        // Case when we were only able to read in the right table of the join\n        } else if let Some(values) = &join_values[i].1 {\n            let current_table = join_tables[i].0.clone();\n            // We only create INLIST filters for fact tables\n            if fact_tables.contains(&current_table) {\n                let current_field = join_fields[i].0.clone();\n                
add_to_existing_set(&mut sets, values.clone(), current_table, current_field);\n            }\n\n            let current_table = join_tables[i].1.clone();\n            // We only create INLIST filters for fact tables\n            if fact_tables.contains(&current_table) {\n                let current_field = join_fields[i].1.clone();\n                add_to_existing_set(&mut sets, values.clone(), current_table, current_field);\n            }\n        }\n    }\n    sets\n}\n\n// Given a mutable hashmap (the hashmap which will eventually be returned by the `combine_sets`\n// function), a hashset of values, a table name, and a column name; insert the hashset of values\n// into the hashmap, where the key is a tuple of the table and column names.\nfn add_to_existing_set(\n    sets: &mut HashMap<(String, String), HashSet<RowValue>>,\n    values: HashSet<RowValue>,\n    current_table: String,\n    current_field: String,\n) {\n    let existing_set = sets.get(&(current_table.clone(), current_field.clone()));\n    match existing_set {\n        // If the tuple for (current_table, current_field) already exists, then we want to combine\n        // the existing set with the new hashset being inserted; to do this, we take the\n        // intersection of both sets.\n        Some(s) => {\n            let s = s.clone();\n            let v = values.iter().cloned().collect::<HashSet<RowValue>>();\n            let s = s.intersection(&v);\n            let mut set_intersection = HashSet::new();\n            for i in s {\n                set_intersection.insert(i.clone());\n            }\n            sets.insert((current_table, current_field), set_intersection.clone());\n        }\n        // If the tuple for (current_table, current_field) does not already exist as a key in the\n        // hashmap, then simply create it and set the hashset as the value\n        None => {\n            sets.insert((current_table, current_field), values);\n        }\n    }\n}\n\n// Given a LogicalPlan and a 
hashmap where the key is a tuple containing a table name and column\n// and the value is a hashset of unique row values, parse the LogicalPlan and insert INLIST filters\n// at the TableScan level.\nfn optimize_table_scans(\n    plan: &LogicalPlan,\n    filter_values: HashMap<(String, String), HashSet<RowValue>>,\n) -> Result<Option<LogicalPlan>> {\n    // Replaces existing TableScan with a new TableScan which includes\n    // the new binary expression filter created from reading in the join columns\n    match plan {\n        LogicalPlan::TableScan(t) => {\n            let table_name = t.table_name.to_string();\n            let table_filters: HashMap<(String, String), HashSet<RowValue>> = filter_values\n                .iter()\n                .filter(|(key, _value)| key.0 == table_name)\n                .map(|(key, value)| ((key.0.to_owned(), key.1.to_owned()), value.clone()))\n                .collect();\n            let mut updated_filters = t.filters.clone();\n            for (key, value) in table_filters.iter() {\n                let current_expr =\n                    format_inlist_expr(value.clone(), key.0.to_owned(), key.1.to_owned());\n                if let Some(e) = current_expr {\n                    updated_filters.push(e);\n                }\n            }\n            let scan = LogicalPlan::TableScan(TableScan {\n                table_name: t.table_name.clone(),\n                source: t.source.clone(),\n                projection: t.projection.clone(),\n                projected_schema: t.projected_schema.clone(),\n                filters: updated_filters,\n                fetch: t.fetch,\n            });\n            Ok(Some(scan))\n        }\n        _ => optimize_children(plan, filter_values),\n    }\n}\n\n// Given a hashset of values, a table name, and a column name, return a DataFusion INLIST Expr\nfn format_inlist_expr(\n    value_set: HashSet<RowValue>,\n    join_table: String,\n    join_field: String,\n) -> Option<Expr> {\n    let expr = 
Box::new(Expr::Column(Column::new(Some(join_table), join_field)));\n    let mut list: Vec<Expr> = vec![];\n\n    // Need to correctly format the ScalarValue type\n    for value in value_set {\n        if let RowValue::String(s) = value {\n            if s.is_some() {\n                let v = Expr::Literal(ScalarValue::Utf8(s));\n                list.push(v);\n            }\n        } else if let RowValue::Int64(l) = value {\n            if l.is_some() {\n                let v = Expr::Literal(ScalarValue::Int64(l));\n                list.push(v);\n            }\n        } else if let RowValue::Int32(i) = value {\n            if i.is_some() {\n                let v = Expr::Literal(ScalarValue::Int32(i));\n                list.push(v);\n            }\n        } else if let RowValue::Double(Some(f)) = value {\n            let v = Expr::Literal(ScalarValue::Float64(Some(f.0)));\n            list.push(v);\n        }\n    }\n\n    if list.is_empty() {\n        None\n    } else {\n        Some(Expr::InList(InList {\n            expr,\n            list,\n            negated: false,\n        }))\n    }\n}\n\n// Given a LogicalPlan and the same hashmap as the `optimize_table_scans` function, correctly\n// iterate through the LogicalPlan nodes. 
Similar to DataFusion's `optimize_children` function, but\n// recurses on the `optimize_table_scans` function instead.\nfn optimize_children(\n    plan: &LogicalPlan,\n    filter_values: HashMap<(String, String), HashSet<RowValue>>,\n) -> Result<Option<LogicalPlan>> {\n    let new_exprs = plan.expressions();\n    let mut new_inputs = Vec::with_capacity(plan.inputs().len());\n    let mut plan_is_changed = false;\n    for input in plan.inputs() {\n        let new_input = optimize_table_scans(input, filter_values.clone())?;\n        plan_is_changed = plan_is_changed || new_input.is_some();\n        new_inputs.push(new_input.unwrap_or_else(|| input.clone()))\n    }\n    if plan_is_changed {\n        Ok(Some(plan.with_new_exprs(new_exprs, &new_inputs)?))\n    } else {\n        Ok(None)\n    }\n}\n"
  },
  {
    "path": "src/sql/optimizer/join_reorder.rs",
    "content": "//! Join reordering based on the paper \"Improving Join Reordering for Large Scale Distributed Computing\"\n//! https://ieeexplore.ieee.org/document/9378281\n\nuse std::collections::HashSet;\n\nuse datafusion_python::{\n    datafusion_common::{Column, Result},\n    datafusion_expr::{Expr, Join, JoinType, LogicalPlan, LogicalPlanBuilder},\n    datafusion_optimizer::{utils, utils::split_conjunction, OptimizerConfig, OptimizerRule},\n};\nuse log::warn;\n\nuse crate::sql::table::DaskTableSource;\n\npub struct JoinReorder {\n    /// Ratio of the size of the dimension tables to fact tables\n    fact_dimension_ratio: f64,\n    /// Maximum number of fact tables to allow in a join\n    max_fact_tables: usize,\n    /// Whether to preserve user-defined order of unfiltered dimensions\n    preserve_user_order: bool,\n    /// Constant to use when determining the number of rows produced by a\n    /// filtered relation\n    filter_selectivity: f64,\n}\n\nimpl JoinReorder {\n    pub fn new(\n        fact_dimension_ratio: Option<f64>,\n        max_fact_tables: Option<usize>,\n        preserve_user_order: Option<bool>,\n        filter_selectivity: Option<f64>,\n    ) -> Self {\n        Self {\n            // FIXME: Default value for fact_dimension_ratio should be 0.3, not 0.7\n            fact_dimension_ratio: fact_dimension_ratio.unwrap_or(0.7),\n            max_fact_tables: max_fact_tables.unwrap_or(2),\n            preserve_user_order: preserve_user_order.unwrap_or(true),\n            filter_selectivity: filter_selectivity.unwrap_or(1.0),\n        }\n    }\n}\n\nimpl OptimizerRule for JoinReorder {\n    fn name(&self) -> &str {\n        \"join_reorder\"\n    }\n\n    fn try_optimize(\n        &self,\n        plan: &LogicalPlan,\n        _config: &dyn OptimizerConfig,\n    ) -> Result<Option<LogicalPlan>> {\n        let original_plan = plan.clone();\n        // Recurse down first\n        // We want the equivalent of Spark's transformUp here\n        let plan = 
utils::optimize_children(self, plan, _config)?;\n\n        match &plan {\n            Some(LogicalPlan::Join(join)) if join.join_type == JoinType::Inner => {\n                optimize_join(self, plan.as_ref().unwrap(), join)\n            }\n            Some(plan) => Ok(Some(plan.clone())),\n            None => match &original_plan {\n                LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {\n                    optimize_join(self, &original_plan, join)\n                }\n                _ => Ok(None),\n            },\n        }\n    }\n}\n\nfn optimize_join(\n    rule: &JoinReorder,\n    plan: &LogicalPlan,\n    join: &Join,\n) -> Result<Option<LogicalPlan>> {\n    // FIXME: Check fact/fact join logic\n\n    if !is_supported_join(join) {\n        return Ok(Some(plan.clone()));\n    }\n\n    // Extract the relations and join conditions\n    let (rels, conds) = extract_inner_joins(plan);\n\n    let mut join_conds = HashSet::new();\n    for cond in &conds {\n        match cond {\n            (Expr::Column(l), Expr::Column(r)) => {\n                join_conds.insert((l.clone(), r.clone()));\n            }\n            _ => {\n                return Ok(Some(plan.clone()));\n            }\n        }\n    }\n\n    // Split rels into facts and dims\n    let largest_rel_size = rels.iter().map(|rel| rel.size).max().unwrap() as f64;\n    // Vectors for the fact and dimension tables, respectively\n    let mut facts = vec![];\n    let mut dims = vec![];\n    for rel in &rels {\n        // If the ratio is larger than the fact_dimension_ratio, it is a fact table\n        // Else, it is a dimension table\n        if rel.size as f64 / largest_rel_size > rule.fact_dimension_ratio {\n            facts.push(rel.clone());\n        } else {\n            dims.push(rel.clone());\n        }\n    }\n\n    if facts.is_empty() || dims.is_empty() {\n        return Ok(Some(plan.clone()));\n    }\n    if facts.len() > rule.max_fact_tables {\n        return 
Ok(Some(plan.clone()));\n    }\n\n    // Get list of dimension tables without a selective predicate\n    let mut unfiltered_dimensions = get_unfiltered_dimensions(&dims);\n    if !rule.preserve_user_order {\n        unfiltered_dimensions.sort_by(|a, b| a.size.cmp(&b.size));\n    }\n\n    // Get list of dimension tables with a selective predicate and sort it\n    let filtered_dimensions = get_filtered_dimensions(&dims);\n    let mut filtered_dimensions: Vec<Relation> = filtered_dimensions\n        .iter()\n        .map(|rel| Relation {\n            plan: rel.plan.clone(),\n            size: (rel.size as f64 * rule.filter_selectivity) as usize,\n        })\n        .collect();\n    filtered_dimensions.sort_by(|a, b| a.size.cmp(&b.size));\n\n    // Merge both the lists of dimensions by giving user order\n    // the preference for tables without a selective predicate,\n    // whereas for tables with selective predicates giving preference\n    // to smaller tables. When comparing the top of both\n    // the lists, if size of the top table in the selective predicate\n    // list is smaller than top of the other list, choose it otherwise\n    // vice-versa.\n    // This algorithm is a greedy approach where smaller\n    // joins with filtered dimension table are preferred for execution\n    // earlier than other Joins to improve Join performance. 
We try to keep\n    // the user order intact when unsure about reordering to make sure\n    // regressions are minimized.\n    let mut result = vec![];\n    while !filtered_dimensions.is_empty() || !unfiltered_dimensions.is_empty() {\n        if !filtered_dimensions.is_empty() {\n            if !unfiltered_dimensions.is_empty() {\n                if filtered_dimensions[0].size < unfiltered_dimensions[0].size {\n                    result.push(filtered_dimensions.remove(0));\n                } else {\n                    result.push(unfiltered_dimensions.remove(0));\n                }\n            } else {\n                result.push(filtered_dimensions.remove(0));\n            }\n        } else {\n            result.push(unfiltered_dimensions.remove(0));\n        }\n    }\n\n    let dim_plans: Vec<LogicalPlan> = result.iter().map(|rel| rel.plan.clone()).collect();\n\n    let optimized = if facts.len() == 1 {\n        build_join_tree(&facts[0].plan, &dim_plans, &mut join_conds)?\n    } else {\n        // Build one join tree for each fact table\n        let fact_dim_joins = facts\n            .iter()\n            .map(|f| build_join_tree(&f.plan, &dim_plans, &mut join_conds))\n            .collect::<Result<Vec<_>>>()?;\n        // Join the trees together\n        build_join_tree(&fact_dim_joins[0], &fact_dim_joins[1..], &mut join_conds)?\n    };\n\n    if join_conds.is_empty() {\n        Ok(Some(optimized))\n    } else {\n        Ok(Some(plan.clone()))\n    }\n}\n\n/// Represents a Fact or Dimension table, possibly nested in a filter\n#[derive(Clone, Debug)]\nstruct Relation {\n    /// Plan containing the table scan for the fact or dimension table\n    /// May also contain Filter and SubqueryAlias\n    plan: LogicalPlan,\n    /// Estimated size of the underlying table before any filtering is applied\n    size: usize,\n}\n\nimpl Relation {\n    fn new(plan: LogicalPlan) -> Self {\n        let size = get_table_size(&plan);\n        match size {\n            Some(s) => 
Self { plan, size: s },\n            None => {\n                warn!(\"Table statistics couldn't be obtained; assuming 100 rows\");\n                Self { plan, size: 100 }\n            }\n        }\n    }\n\n    /// Determine if this plan contains any filters\n    fn has_filter(&self) -> bool {\n        has_filter(&self.plan)\n    }\n}\n\nfn has_filter(plan: &LogicalPlan) -> bool {\n    /// We want to ignore \"IsNotNull\" filters that are added for join keys since they exist\n    /// for most dimension tables\n    fn is_real_filter(predicate: &Expr) -> bool {\n        let exprs = split_conjunction(predicate);\n        let x = exprs\n            .iter()\n            .filter(|e| !matches!(e, Expr::IsNotNull(_)))\n            .count();\n        x > 0\n    }\n\n    match plan {\n        LogicalPlan::Filter(filter) => is_real_filter(&filter.predicate),\n        LogicalPlan::TableScan(scan) => scan.filters.iter().any(is_real_filter),\n        _ => plan.inputs().iter().any(|child| has_filter(child)),\n    }\n}\n\n/// Simple Join Constraint: Only INNER Joins are considered\n/// which can be composed of other Joins too. But apart\n/// from the Joins, none of the operator in both the left and\n/// right side of the join should be non-deterministic, or have\n/// output greater than the input to the operator. For instance,\n/// Filter would be allowed operator as it reduces the output\n/// over input, but a project adding extra column will not\n/// be allowed. 
It is difficult to reason about operators that\n/// add extra output when dealing with just table sizes, so\n/// instead we only allow operators from a selected set of\n/// operators\nfn is_supported_join(join: &Join) -> bool {\n    // FIXME: Check for deterministic filter expressions\n\n    fn is_supported_rel(plan: &LogicalPlan) -> bool {\n        match plan {\n            LogicalPlan::Join(join) => {\n                join.join_type == JoinType::Inner\n                    // FIXME: Need to support join filters correctly\n                    && join.filter.is_none()\n                    && is_supported_rel(&join.left)\n                    && is_supported_rel(&join.right)\n            }\n            LogicalPlan::Filter(filter) => is_supported_rel(&filter.input),\n            LogicalPlan::SubqueryAlias(sq) => is_supported_rel(&sq.input),\n            LogicalPlan::TableScan(_) => true,\n            _ => false,\n        }\n    }\n\n    is_supported_rel(&LogicalPlan::Join(join.clone()))\n}\n\n/// Extracts items of consecutive inner joins and join conditions\n/// This method works for bushy trees and left/right deep trees\nfn extract_inner_joins(plan: &LogicalPlan) -> (Vec<Relation>, HashSet<(Expr, Expr)>) {\n    fn _extract_inner_joins(\n        plan: &LogicalPlan,\n        rels: &mut Vec<LogicalPlan>,\n        conds: &mut HashSet<(Expr, Expr)>,\n    ) {\n        match plan {\n            LogicalPlan::Join(join)\n                if join.join_type == JoinType::Inner && join.filter.is_none() =>\n            {\n                _extract_inner_joins(&join.left, rels, conds);\n                _extract_inner_joins(&join.right, rels, conds);\n\n                for (l, r) in &join.on {\n                    conds.insert((l.clone(), r.clone()));\n                }\n            }\n            /* FIXME: Need to support join filters correctly\n            LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {\n                _extract_inner_joins(&join.left, rels, 
conds);\n                _extract_inner_joins(&join.right, rels, conds);\n\n                for (l, r) in &join.on {\n                    conds.insert((l.clone(), r.clone()));\n                }\n\n                // Need to save this info somewhere\n                let join_filter = join.filter.as_ref().unwrap();\n            } */\n            _ => {\n                if find_join(plan).is_some() {\n                    for x in plan.inputs() {\n                        _extract_inner_joins(x, rels, conds);\n                    }\n                } else {\n                    // Leaf node\n                    rels.push(plan.clone())\n                }\n            }\n        }\n    }\n\n    let mut rels = vec![];\n    let mut conds = HashSet::new();\n    _extract_inner_joins(plan, &mut rels, &mut conds);\n    let rels = rels.into_iter().map(Relation::new).collect();\n    (rels, conds)\n}\n\n/// Find first (top-level) join in plan\nfn find_join(plan: &LogicalPlan) -> Option<Join> {\n    match plan {\n        LogicalPlan::Join(join) => Some(join.clone()),\n        other => {\n            if other.inputs().is_empty() {\n                None\n            } else {\n                for input in &other.inputs() {\n                    if let Some(join) = find_join(input) {\n                        return Some(join);\n                    }\n                }\n                None\n            }\n        }\n    }\n}\n\nfn get_unfiltered_dimensions(dims: &[Relation]) -> Vec<Relation> {\n    dims.iter().filter(|t| !t.has_filter()).cloned().collect()\n}\n\nfn get_filtered_dimensions(dims: &[Relation]) -> Vec<Relation> {\n    dims.iter().filter(|t| t.has_filter()).cloned().collect()\n}\n\nfn build_join_tree(\n    fact: &LogicalPlan,\n    dims: &[LogicalPlan],\n    conds: &mut HashSet<(Column, Column)>,\n) -> Result<LogicalPlan> {\n    let mut b = LogicalPlanBuilder::from(fact.clone());\n    for dim in dims {\n        // Find join keys between the fact and this dim\n        let mut 
join_keys = vec![];\n        for (l, r) in conds.iter() {\n            if (b.schema().index_of_column(l).is_ok() && dim.schema().index_of_column(r).is_ok())\n                || b.schema().index_of_column(r).is_ok() && dim.schema().index_of_column(l).is_ok()\n            {\n                join_keys.push((l.clone(), r.clone()));\n            }\n        }\n        if !join_keys.is_empty() {\n            let left_keys: Vec<Column> = join_keys.iter().map(|(l, _r)| l.clone()).collect();\n            let right_keys: Vec<Column> = join_keys.iter().map(|(_l, r)| r.clone()).collect();\n\n            for key in join_keys {\n                conds.remove(&key);\n            }\n\n            /* FIXME: Build join with join_keys when needed\n            self.join(\n                right: LogicalPlan,\n                join_type: JoinType,\n                join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>),\n                filter: Option<Expr>,\n            ) */\n            b = b.join(dim.clone(), JoinType::Inner, (left_keys, right_keys), None)?;\n        }\n    }\n    b.build()\n}\n\nfn get_table_size(plan: &LogicalPlan) -> Option<usize> {\n    match plan {\n        LogicalPlan::TableScan(scan) => scan\n            .source\n            .as_any()\n            .downcast_ref::<DaskTableSource>()\n            .expect(\"should be a DaskTableSource\")\n            .statistics()\n            .map(|stats| stats.get_row_count() as usize),\n        // Follow the first input; a plan with no inputs that is not a TableScan\n        // has no statistics, so return None instead of panicking on an empty Vec\n        _ => plan\n            .inputs()\n            .first()\n            .and_then(|input| get_table_size(input)),\n    }\n}\n"
  },
  {
    "path": "src/sql/optimizer/utils.rs",
    "content": "// Licensed to the Apache Software Foundation (ASF) under one\n// or more contributor license agreements.  See the NOTICE file\n// distributed with this work for additional information\n// regarding copyright ownership.  The ASF licenses this file\n// to you under the Apache License, Version 2.0 (the\n// \"License\"); you may not use this file except in compliance\n// with the License.  You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing,\n// software distributed under the License is distributed on an\n// \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n// KIND, either express or implied.  See the License for the\n// specific language governing permissions and limitations\n// under the License.\n\n//! Collection of utility functions that are leveraged by the query optimizer rules\n\nuse std::{\n    collections::{BTreeSet, HashMap},\n    sync::Arc,\n};\n\nuse datafusion_python::{\n    datafusion_common::{Column, DFSchema, DFSchemaRef, Result},\n    datafusion_expr::{\n        and,\n        expr::{Alias, BinaryExpr},\n        expr_rewriter::{replace_col, strip_outer_reference},\n        logical_plan::{Filter, LogicalPlan},\n        Expr,\n        LogicalPlanBuilder,\n        Operator,\n    },\n    datafusion_optimizer::optimizer::{OptimizerConfig, OptimizerRule},\n};\nuse log::{debug, trace};\n\n#[allow(dead_code)]\n/// Convenience rule for writing optimizers: recursively invoke\n/// optimize on plan's children and then return a node of the same\n/// type. 
Useful for optimizer rules which want to leave the type\n/// of plan unchanged but still apply to the children.\n/// This also handles the case when the `plan` is a [`LogicalPlan::Explain`].\n///\n/// Returning `Ok(None)` indicates that the plan can't be optimized by the `optimizer`.\npub fn optimize_children(\n    optimizer: &impl OptimizerRule,\n    plan: &LogicalPlan,\n    config: &dyn OptimizerConfig,\n) -> Result<Option<LogicalPlan>> {\n    let mut new_inputs = Vec::with_capacity(plan.inputs().len());\n    let mut plan_is_changed = false;\n    for input in plan.inputs() {\n        let new_input = optimizer.try_optimize(input, config)?;\n        plan_is_changed = plan_is_changed || new_input.is_some();\n        new_inputs.push(new_input.unwrap_or_else(|| input.clone()))\n    }\n    if plan_is_changed {\n        Ok(Some(plan.with_new_inputs(&new_inputs)?))\n    } else {\n        Ok(None)\n    }\n}\n\n/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`\n///\n/// See [`split_conjunction_owned`] for more details and an example.\npub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {\n    split_conjunction_impl(expr, vec![])\n}\n\nfn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {\n    match expr {\n        Expr::BinaryExpr(BinaryExpr {\n            right,\n            op: Operator::And,\n            left,\n        }) => {\n            let exprs = split_conjunction_impl(left, exprs);\n            split_conjunction_impl(right, exprs)\n        }\n        Expr::Alias(Alias { expr, .. 
}) => split_conjunction_impl(expr, exprs),\n        other => {\n            exprs.push(other);\n            exprs\n        }\n    }\n}\n\n/// Extract join predicates from the correclated subquery.\n/// The join predicate means that the expression references columns\n/// from both the subquery and outer table or only from the outer table.\n///\n/// Returns join predicates and subquery(extracted).\npub(crate) fn extract_join_filters(maybe_filter: &LogicalPlan) -> Result<(Vec<Expr>, LogicalPlan)> {\n    if let LogicalPlan::Filter(plan_filter) = maybe_filter {\n        let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);\n        let (join_filters, subquery_filters) = find_join_exprs(subquery_filter_exprs)?;\n        // if the subquery still has filter expressions, restore them.\n        let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone());\n        if let Some(expr) = conjunction(subquery_filters) {\n            plan = plan.filter(expr)?\n        }\n\n        Ok((join_filters, plan.build()?))\n    } else {\n        Ok((vec![], maybe_filter.clone()))\n    }\n}\n\n#[allow(dead_code)]\n/// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`\n///\n/// This is often used to \"split\" filter expressions such as `col1 = 5\n/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];\n///\n/// # Example\n/// ```\n/// # use datafusion_python::datafusion_expr::{col, lit};\n/// # use datafusion_python::datafusion_optimizer::utils::split_conjunction_owned;\n/// // a=1 AND b=2\n/// let expr = col(\"a\").eq(lit(1)).and(col(\"b\").eq(lit(2)));\n///\n/// // [a=1, b=2]\n/// let split = vec![\n///   col(\"a\").eq(lit(1)),\n///   col(\"b\").eq(lit(2)),\n/// ];\n///\n/// // use split_conjunction_owned to split them\n/// assert_eq!(split_conjunction_owned(expr), split);\n/// ```\npub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {\n    split_binary_owned(expr, Operator::And)\n}\n\n#[allow(dead_code)]\n/// Splits an owned binary 
operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`\n///\n/// This is often used to \"split\" expressions such as `col1 = 5\n/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];\n///\n/// # Example\n/// ```\n/// # use datafusion_python::datafusion_expr::{col, lit, Operator};\n/// # use datafusion_python::datafusion_optimizer::utils::split_binary_owned;\n/// # use std::ops::Add;\n/// // a=1 + b=2\n/// let expr = col(\"a\").eq(lit(1)).add(col(\"b\").eq(lit(2)));\n///\n/// // [a=1, b=2]\n/// let split = vec![\n///   col(\"a\").eq(lit(1)),\n///   col(\"b\").eq(lit(2)),\n/// ];\n///\n/// // use split_binary_owned to split them\n/// assert_eq!(split_binary_owned(expr, Operator::Plus), split);\n/// ```\npub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {\n    split_binary_owned_impl(expr, op, vec![])\n}\n\n#[allow(dead_code)]\nfn split_binary_owned_impl(expr: Expr, operator: Operator, mut exprs: Vec<Expr>) -> Vec<Expr> {\n    match expr {\n        Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {\n            let exprs = split_binary_owned_impl(*left, operator, exprs);\n            split_binary_owned_impl(*right, operator, exprs)\n        }\n        Expr::Alias(Alias { expr, .. 
}) => split_binary_owned_impl(*expr, operator, exprs),\n        other => {\n            exprs.push(other);\n            exprs\n        }\n    }\n}\n\n#[allow(dead_code)]\n/// Splits an binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`\n///\n/// See [`split_binary_owned`] for more details and an example.\npub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {\n    split_binary_impl(expr, op, vec![])\n}\n\n#[allow(dead_code)]\nfn split_binary_impl<'a>(\n    expr: &'a Expr,\n    operator: Operator,\n    mut exprs: Vec<&'a Expr>,\n) -> Vec<&'a Expr> {\n    match expr {\n        Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => {\n            let exprs = split_binary_impl(left, operator, exprs);\n            split_binary_impl(right, operator, exprs)\n        }\n        Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs),\n        other => {\n            exprs.push(other);\n            exprs\n        }\n    }\n}\n\n/// Combines an array of filter expressions into a single filter\n/// expression consisting of the input filter expressions joined with\n/// logical AND.\n///\n/// Returns None if the filters array is empty.\n///\n/// # Example\n/// ```\n/// # use datafusion_python::datafusion_expr::{col, lit};\n/// # use datafusion_python::datafusion_optimizer::utils::conjunction;\n/// // a=1 AND b=2\n/// let expr = col(\"a\").eq(lit(1)).and(col(\"b\").eq(lit(2)));\n///\n/// // [a=1, b=2]\n/// let split = vec![\n///   col(\"a\").eq(lit(1)),\n///   col(\"b\").eq(lit(2)),\n/// ];\n///\n/// // use conjunction to join them together with `AND`\n/// assert_eq!(conjunction(split), Some(expr));\n/// ```\npub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {\n    filters.into_iter().reduce(|accum, expr| accum.and(expr))\n}\n\n#[allow(dead_code)]\n/// Combines an array of filter expressions into a single filter\n/// expression consisting of the input filter expressions joined with\n/// 
logical OR.\n///\n/// Returns None if the filters array is empty.\npub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {\n    filters.into_iter().reduce(|accum, expr| accum.or(expr))\n}\n\n/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with\n/// its predicate be all `predicates` ANDed.\n#[allow(dead_code)]\npub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {\n    // reduce filters to a single filter with an AND\n    let predicate = predicates\n        .iter()\n        .skip(1)\n        .fold(predicates[0].clone(), |acc, predicate| {\n            and(acc, (*predicate).to_owned())\n        });\n\n    Ok(LogicalPlan::Filter(Filter::try_new(\n        predicate,\n        Arc::new(plan),\n    )?))\n}\n\n/// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and\n/// one not in the subquery (closed upon from outer scope)\n///\n/// # Arguments\n///\n/// * `exprs` - List of expressions that may or may not be joins\n///\n/// # Return value\n///\n/// Tuple of (expressions containing joins, remaining non-join expressions)\npub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {\n    let mut joins = vec![];\n    let mut others = vec![];\n    for filter in exprs.into_iter() {\n        // If the expression contains correlated predicates, add it to join filters\n        if filter.contains_outer() {\n            if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right))\n            {\n                joins.push(strip_outer_reference((*filter).clone()));\n            }\n        } else {\n            others.push((*filter).clone());\n        }\n    }\n\n    Ok((joins, others))\n}\n\n/// Returns the first (and only) element in a slice, or an error\n///\n/// # Arguments\n///\n/// * `slice` - The slice to extract from\n///\n/// # Return value\n///\n/// The first element, or an error\npub 
fn only_or_err<T>(slice: &[T]) -> Result<&T> {\n    match slice {\n        [it] => Ok(it),\n        [] => Err(datafusion_python::datafusion_common::DataFusionError::Plan(\n            \"No items found!\".to_owned(),\n        )),\n        _ => Err(datafusion_python::datafusion_common::DataFusionError::Plan(\n            \"More than one item found!\".to_owned(),\n        )),\n    }\n}\n\n/// merge inputs schema into a single schema.\n#[allow(dead_code)]\npub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {\n    if inputs.len() == 1 {\n        inputs[0].schema().clone().as_ref().clone()\n    } else {\n        inputs\n            .iter()\n            .map(|input| input.schema())\n            .fold(DFSchema::empty(), |mut lhs, rhs| {\n                lhs.merge(rhs);\n                lhs\n            })\n    }\n}\n\npub(crate) fn collect_subquery_cols(\n    exprs: &[Expr],\n    subquery_schema: DFSchemaRef,\n) -> Result<BTreeSet<Column>> {\n    exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| {\n        let mut using_cols: Vec<Column> = vec![];\n        for col in expr.to_columns()?.into_iter() {\n            if subquery_schema.has_column(&col) {\n                using_cols.push(col);\n            }\n        }\n\n        cols.extend(using_cols);\n        Result::<_>::Ok(cols)\n    })\n}\n\npub(crate) fn replace_qualified_name(\n    expr: Expr,\n    cols: &BTreeSet<Column>,\n    subquery_alias: &str,\n) -> Result<Expr> {\n    let alias_cols: Vec<Column> = cols\n        .iter()\n        .map(|col| Column::from_qualified_name(format!(\"{}.{}\", subquery_alias, col.name)))\n        .collect();\n    let replace_map: HashMap<&Column, &Column> = cols.iter().zip(alias_cols.iter()).collect();\n\n    replace_col(expr, &replace_map)\n}\n\n#[allow(dead_code)]\n/// Log the plan in debug/tracing mode after some part of the optimizer runs\npub fn log_plan(description: &str, plan: &LogicalPlan) {\n    debug!(\"{description}:\\n{}\\n\", plan.display_indent());\n    
trace!(\"{description}::\\n{}\\n\", plan.display_indent_schema());\n}\n\n#[cfg(test)]\nmod tests {\n    use std::collections::HashSet;\n\n    use datafusion_python::{\n        datafusion::arrow::datatypes::DataType,\n        datafusion_common::Column,\n        datafusion_expr::{col, expr::Cast, lit, utils::expr_to_columns},\n    };\n\n    use super::*;\n\n    #[test]\n    fn test_split_conjunction() {\n        let expr = col(\"a\");\n        let result = split_conjunction(&expr);\n        assert_eq!(result, vec![&expr]);\n    }\n\n    #[test]\n    fn test_split_conjunction_two() {\n        let expr = col(\"a\").eq(lit(5)).and(col(\"b\"));\n        let expr1 = col(\"a\").eq(lit(5));\n        let expr2 = col(\"b\");\n\n        let result = split_conjunction(&expr);\n        assert_eq!(result, vec![&expr1, &expr2]);\n    }\n\n    #[test]\n    fn test_split_conjunction_alias() {\n        let expr = col(\"a\").eq(lit(5)).and(col(\"b\").alias(\"the_alias\"));\n        let expr1 = col(\"a\").eq(lit(5));\n        let expr2 = col(\"b\"); // has no alias\n\n        let result = split_conjunction(&expr);\n        assert_eq!(result, vec![&expr1, &expr2]);\n    }\n\n    #[test]\n    fn test_split_conjunction_or() {\n        let expr = col(\"a\").eq(lit(5)).or(col(\"b\"));\n        let result = split_conjunction(&expr);\n        assert_eq!(result, vec![&expr]);\n    }\n\n    #[test]\n    fn test_split_binary_owned() {\n        let expr = col(\"a\");\n        assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]);\n    }\n\n    #[test]\n    fn test_split_binary_owned_two() {\n        assert_eq!(\n            split_binary_owned(col(\"a\").eq(lit(5)).and(col(\"b\")), Operator::And),\n            vec![col(\"a\").eq(lit(5)), col(\"b\")]\n        );\n    }\n\n    #[test]\n    fn test_split_binary_owned_different_op() {\n        let expr = col(\"a\").eq(lit(5)).or(col(\"b\"));\n        assert_eq!(\n            // expr is connected by OR, but pass in AND\n            
split_binary_owned(expr.clone(), Operator::And),\n            vec![expr]\n        );\n    }\n\n    #[test]\n    fn test_split_conjunction_owned() {\n        let expr = col(\"a\");\n        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);\n    }\n\n    #[test]\n    fn test_split_conjunction_owned_two() {\n        assert_eq!(\n            split_conjunction_owned(col(\"a\").eq(lit(5)).and(col(\"b\"))),\n            vec![col(\"a\").eq(lit(5)), col(\"b\")]\n        );\n    }\n\n    #[test]\n    fn test_split_conjunction_owned_alias() {\n        assert_eq!(\n            split_conjunction_owned(col(\"a\").eq(lit(5)).and(col(\"b\").alias(\"the_alias\"))),\n            vec![\n                col(\"a\").eq(lit(5)),\n                // no alias on b\n                col(\"b\"),\n            ]\n        );\n    }\n\n    #[test]\n    fn test_conjunction_empty() {\n        assert_eq!(conjunction(vec![]), None);\n    }\n\n    #[test]\n    fn test_conjunction() {\n        // `[A, B, C]`\n        let expr = conjunction(vec![col(\"a\"), col(\"b\"), col(\"c\")]);\n\n        // --> `(A AND B) AND C`\n        assert_eq!(expr, Some(col(\"a\").and(col(\"b\")).and(col(\"c\"))));\n\n        // which is different than `A AND (B AND C)`\n        assert_ne!(expr, Some(col(\"a\").and(col(\"b\").and(col(\"c\")))));\n    }\n\n    #[test]\n    fn test_disjunction_empty() {\n        assert_eq!(disjunction(vec![]), None);\n    }\n\n    #[test]\n    fn test_disjunction() {\n        // `[A, B, C]`\n        let expr = disjunction(vec![col(\"a\"), col(\"b\"), col(\"c\")]);\n\n        // --> `(A OR B) OR C`\n        assert_eq!(expr, Some(col(\"a\").or(col(\"b\")).or(col(\"c\"))));\n\n        // which is different than `A OR (B OR C)`\n        assert_ne!(expr, Some(col(\"a\").or(col(\"b\").or(col(\"c\")))));\n    }\n\n    #[test]\n    fn test_split_conjunction_owned_or() {\n        let expr = col(\"a\").eq(lit(5)).or(col(\"b\"));\n        assert_eq!(split_conjunction_owned(expr.clone()), 
vec![expr]);\n    }\n\n    #[test]\n    fn test_collect_expr() -> Result<()> {\n        let mut accum: HashSet<Column> = HashSet::new();\n        expr_to_columns(\n            &Expr::Cast(Cast::new(Box::new(col(\"a\")), DataType::Float64)),\n            &mut accum,\n        )?;\n        expr_to_columns(\n            &Expr::Cast(Cast::new(Box::new(col(\"a\")), DataType::Float64)),\n            &mut accum,\n        )?;\n        assert_eq!(1, accum.len());\n        assert!(accum.contains(&Column::from_name(\"a\")));\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "src/sql/optimizer.rs",
    "content": "// Declare optimizer modules\npub mod decorrelate_where_exists;\npub mod decorrelate_where_in;\npub mod dynamic_partition_pruning;\npub mod join_reorder;\npub mod utils;\n\nuse std::sync::Arc;\n\nuse datafusion_python::{\n    datafusion_common::DataFusionError,\n    datafusion_expr::LogicalPlan,\n    datafusion_optimizer::{\n        eliminate_cross_join::EliminateCrossJoin,\n        eliminate_limit::EliminateLimit,\n        eliminate_outer_join::EliminateOuterJoin,\n        eliminate_project::EliminateProjection,\n        filter_null_join_keys::FilterNullJoinKeys,\n        optimizer::{Optimizer, OptimizerRule},\n        push_down_filter::PushDownFilter,\n        push_down_limit::PushDownLimit,\n        push_down_projection::PushDownProjection,\n        rewrite_disjunctive_predicate::RewriteDisjunctivePredicate,\n        scalar_subquery_to_join::ScalarSubqueryToJoin,\n        simplify_expressions::SimplifyExpressions,\n        unwrap_cast_in_comparison::UnwrapCastInComparison,\n        OptimizerContext,\n    },\n};\nuse decorrelate_where_exists::DecorrelateWhereExists;\nuse decorrelate_where_in::DecorrelateWhereIn;\nuse dynamic_partition_pruning::DynamicPartitionPruning;\nuse join_reorder::JoinReorder;\nuse log::{debug, trace};\n\n/// Houses the optimization logic for Dask-SQL. 
This optimization controls the optimizations\n/// and their ordering in regards to their impact on the underlying `LogicalPlan` instance\npub struct DaskSqlOptimizer {\n    optimizer: Optimizer,\n}\n\nimpl DaskSqlOptimizer {\n    /// Creates a new instance of the DaskSqlOptimizer with all the DataFusion desired\n    /// optimizers as well as any custom `OptimizerRule` trait impls that might be desired.\n    pub fn new(\n        fact_dimension_ratio: Option<f64>,\n        max_fact_tables: Option<usize>,\n        preserve_user_order: Option<bool>,\n        filter_selectivity: Option<f64>,\n    ) -> Self {\n        debug!(\"Creating new instance of DaskSqlOptimizer\");\n\n        let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![\n            Arc::new(SimplifyExpressions::new()),\n            Arc::new(UnwrapCastInComparison::new()),\n            // Arc::new(ReplaceDistinctWithAggregate::new()),\n            Arc::new(DecorrelateWhereExists::new()),\n            Arc::new(DecorrelateWhereIn::new()),\n            Arc::new(ScalarSubqueryToJoin::new()),\n            //Arc::new(ExtractEquijoinPredicate::new()),\n\n            // simplify expressions does not simplify expressions in subqueries, so we\n            // run it again after running the optimizations that potentially converted\n            // subqueries to joins\n            Arc::new(SimplifyExpressions::new()),\n            // Arc::new(MergeProjection::new()),\n            Arc::new(RewriteDisjunctivePredicate::new()),\n            // Arc::new(EliminateDuplicatedExpr::new()),\n\n            // TODO: need to handle EmptyRelation for GPU cases\n            // Arc::new(EliminateFilter::new()),\n            Arc::new(EliminateCrossJoin::new()),\n            // Arc::new(CommonSubexprEliminate::new()),\n            Arc::new(EliminateLimit::new()),\n            // Arc::new(PropagateEmptyRelation::new()),\n            Arc::new(FilterNullJoinKeys::default()),\n            Arc::new(EliminateOuterJoin::new()),\n       
     // Filters can't be pushed down past Limits, we should do PushDownFilter after PushDownLimit\n            Arc::new(PushDownLimit::new()),\n            Arc::new(PushDownFilter::new()),\n            // Arc::new(SingleDistinctToGroupBy::new()),\n            // Dask-SQL specific optimizations\n            Arc::new(JoinReorder::new(\n                fact_dimension_ratio,\n                max_fact_tables,\n                preserve_user_order,\n                filter_selectivity,\n            )),\n            // The previous optimizations added expressions and projections,\n            // that might benefit from the following rules\n            Arc::new(SimplifyExpressions::new()),\n            Arc::new(UnwrapCastInComparison::new()),\n            // Arc::new(CommonSubexprEliminate::new()),\n            Arc::new(PushDownProjection::new()),\n            Arc::new(EliminateProjection::new()),\n            // PushDownProjection can pushdown Projections through Limits, do PushDownLimit again.\n            Arc::new(PushDownLimit::new()),\n        ];\n\n        Self {\n            optimizer: Optimizer::with_rules(rules),\n        }\n    }\n\n    // Create a separate instance of this optimization rule, since we want to ensure that it only\n    // runs one time\n    pub fn dynamic_partition_pruner(fact_dimension_ratio: Option<f64>) -> Self {\n        let rule: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![Arc::new(\n            DynamicPartitionPruning::new(fact_dimension_ratio.unwrap_or(0.3)),\n        )];\n\n        Self {\n            optimizer: Optimizer::with_rules(rule),\n        }\n    }\n\n    /// Iterates through the configured `OptimizerRule`(s) to transform the input `LogicalPlan`\n    /// to its final optimized form\n    pub(crate) fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan, DataFusionError> {\n        let config = OptimizerContext::new();\n        self.optimizer.optimize(&plan, &config, Self::observe)\n    }\n\n    /// Iterates once through the 
configured `OptimizerRule`(s) to transform the input `LogicalPlan`\n    /// to its final optimized form\n    pub(crate) fn optimize_once(&self, plan: LogicalPlan) -> Result<LogicalPlan, DataFusionError> {\n        let mut config = OptimizerContext::new();\n        config = OptimizerContext::with_max_passes(config, 1);\n        self.optimizer.optimize(&plan, &config, Self::observe)\n    }\n\n    fn observe(optimized_plan: &LogicalPlan, optimization: &dyn OptimizerRule) {\n        trace!(\n            \"== AFTER APPLYING RULE {} ==\\n{}\\n\",\n            optimization.name(),\n            optimized_plan.display_indent()\n        );\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use std::{any::Any, collections::HashMap, sync::Arc};\n\n    use datafusion_python::{\n        datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef},\n        datafusion_common::{config::ConfigOptions, DataFusionError, Result},\n        datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource},\n        datafusion_sql::{\n            planner::{ContextProvider, SqlToRel},\n            sqlparser::{ast::Statement, parser::Parser},\n            TableReference,\n        },\n    };\n\n    use crate::{dialect::DaskDialect, sql::optimizer::DaskSqlOptimizer};\n\n    #[test]\n    fn subquery_filter_with_cast() -> Result<()> {\n        // regression test for https://github.com/apache/arrow-datafusion/issues/3760\n        let sql = \"SELECT col_int32 FROM test \\\n    WHERE col_int32 > (\\\n      SELECT AVG(col_int32) FROM test \\\n      WHERE col_utf8 BETWEEN '2002-05-08' \\\n        AND (cast('2002-05-08' as date) + interval '5 days')\\\n    )\";\n        let plan = test_sql(sql)?;\n        assert!(format!(\"{:?}\", plan).contains(r#\"<= Date32(\"11820\")\"#));\n        Ok(())\n    }\n\n    fn test_sql(sql: &str) -> Result<LogicalPlan> {\n        // parse the SQL\n        let dialect = DaskDialect {};\n        let ast: Vec<Statement> = Parser::parse_sql(&dialect, sql).unwrap();\n    
    let statement = &ast[0];\n\n        // create a logical query plan\n        let schema_provider = MySchemaProvider::new();\n        let sql_to_rel = SqlToRel::new(&schema_provider);\n        let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();\n\n        // optimize the logical plan\n        let optimizer = DaskSqlOptimizer::new(None, None, None, None);\n        optimizer.optimize(plan)\n    }\n\n    struct MySchemaProvider {\n        options: ConfigOptions,\n    }\n\n    impl MySchemaProvider {\n        fn new() -> Self {\n            Self {\n                options: ConfigOptions::default(),\n            }\n        }\n    }\n\n    impl ContextProvider for MySchemaProvider {\n        fn options(&self) -> &ConfigOptions {\n            &self.options\n        }\n\n        fn get_table_provider(\n            &self,\n            name: TableReference,\n        ) -> datafusion_python::datafusion_common::Result<Arc<dyn TableSource>> {\n            let table_name = name.table();\n            if table_name.starts_with(\"test\") {\n                let schema = Schema::new_with_metadata(\n                    vec![\n                        Field::new(\"col_int32\", DataType::Int32, true),\n                        Field::new(\"col_uint32\", DataType::UInt32, true),\n                        Field::new(\"col_utf8\", DataType::Utf8, true),\n                        Field::new(\"col_date32\", DataType::Date32, true),\n                        Field::new(\"col_date64\", DataType::Date64, true),\n                    ],\n                    HashMap::new(),\n                );\n\n                Ok(Arc::new(MyTableSource {\n                    schema: Arc::new(schema),\n                }))\n            } else {\n                Err(DataFusionError::Plan(\"table does not exist\".to_string()))\n            }\n        }\n\n        fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {\n            None\n        }\n\n        fn get_aggregate_meta(&self, 
_name: &str) -> Option<Arc<AggregateUDF>> {\n            None\n        }\n\n        fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {\n            None\n        }\n\n        fn get_window_meta(\n            &self,\n            _name: &str,\n        ) -> Option<Arc<datafusion_python::datafusion_expr::WindowUDF>> {\n            None\n        }\n    }\n\n    struct MyTableSource {\n        schema: SchemaRef,\n    }\n\n    impl TableSource for MyTableSource {\n        fn as_any(&self) -> &dyn Any {\n            self\n        }\n\n        fn schema(&self) -> SchemaRef {\n            self.schema.clone()\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/parser_utils.rs",
    "content": "use datafusion_python::datafusion_sql::sqlparser::{ast::ObjectName, parser::ParserError};\n\npub struct DaskParserUtils;\n\nimpl DaskParserUtils {\n    /// Retrieves the schema and object name from a `ObjectName` instance\n    pub fn elements_from_object_name(\n        obj_name: &ObjectName,\n    ) -> Result<(Option<String>, String), ParserError> {\n        let identities: Vec<String> = obj_name.0.iter().map(|f| f.value.clone()).collect();\n\n        match identities.len() {\n            1 => Ok((None, identities[0].clone())),\n            2 => Ok((Some(identities[0].clone()), identities[1].clone())),\n            _ => Err(ParserError::ParserError(\n                \"TableFactor name only supports 1 or 2 elements\".to_string(),\n            )),\n        }\n    }\n}\n"
  },
  {
    "path": "src/sql/preoptimizer.rs",
    "content": "use std::collections::HashMap;\n\nuse datafusion_python::{\n    datafusion::arrow::datatypes::{DataType, TimeUnit},\n    datafusion_common::{Column, DFField, ScalarValue},\n    datafusion_expr::{logical_plan::Filter, BinaryExpr, Expr, LogicalPlan, Operator},\n};\n\n// Sometimes, DataFusion's optimizer will raise an OptimizationException before we even get the\n// chance to correct it anywhere. In these cases, we can still modify the LogicalPlan as an\n// optimizer rule would, however we have to run it independently, separately of DataFusion's\n// optimization framework. Ideally, these \"pre-optimization\" rules aren't performing any complex\n// logic, but rather \"pre-processing\" the LogicalPlan for the optimizer. For example, the\n// datetime_coercion preoptimizer rule fixes a bug involving Timestamp-Int operations.\n\n// Helper function for datetime_coercion rule, which returns a vector of columns and literals\n// involved in a (possibly nested) BinaryExpr mathematical expression, and the mathematical\n// BinaryExpr itself\nfn extract_columns_and_literals(expr: &Expr) -> Vec<(Vec<Expr>, Expr)> {\n    let mut result = Vec::new();\n    if let Expr::BinaryExpr(b) = expr {\n        let left = *b.left.clone();\n        let right = *b.right.clone();\n        if let Operator::Plus\n        | Operator::Minus\n        | Operator::Multiply\n        | Operator::Divide\n        | Operator::Modulo = &b.op\n        {\n            let mut operands = Vec::new();\n            if let Expr::Column(_) | Expr::Literal(_) = left.clone() {\n                operands.push(left);\n            } else {\n                let vector_of_vectors = extract_columns_and_literals(&left);\n                let mut flattened = Vec::new();\n                for vector in vector_of_vectors {\n                    flattened.extend(vector.0);\n                }\n                operands.append(&mut flattened);\n            }\n\n            if let Expr::Column(_) | Expr::Literal(_) = 
right.clone() {\n                operands.push(right);\n            } else {\n                let vector_of_vectors = extract_columns_and_literals(&right);\n                let mut flattened = Vec::new();\n                for vector in vector_of_vectors {\n                    flattened.extend(vector.0);\n                }\n                operands.append(&mut flattened);\n            }\n\n            result.push((operands, expr.clone()));\n        } else {\n            if let Expr::BinaryExpr(_) = left {\n                result.append(&mut extract_columns_and_literals(&left));\n            }\n\n            if let Expr::BinaryExpr(_) = right {\n                result.append(&mut extract_columns_and_literals(&right));\n            }\n        }\n    }\n    result\n}\n\n// Helper function for datetime_coercion rule, which uses a LogicalPlan's schema to obtain the\n// datatype of a desired column\nfn find_data_type(column: Column, fields: Vec<DFField>) -> Option<DataType> {\n    for field in fields {\n        if let Some(qualifier) = field.qualifier() {\n            if column.relation.is_some()\n                && qualifier.table() == column.relation.clone().unwrap().table()\n                && field.field().name() == &column.name\n            {\n                return Some(field.field().data_type().clone());\n            }\n        }\n    }\n    None\n}\n\n// Helper function for datetime_coercion rule, which, given a BinaryExpr and a HashMap in which the\n// key represents a Literal and the value represents a Literal to replace the key with, returns the\n// modified BinaryExpr\nfn replace_literals(expr: Expr, replacements: HashMap<Expr, Expr>) -> Expr {\n    match expr {\n        Expr::Literal(l) => {\n            if let Some(new_literal) = replacements.get(&Expr::Literal(l.clone())) {\n                new_literal.clone()\n            } else {\n                Expr::Literal(l)\n            }\n        }\n        Expr::BinaryExpr(b) => {\n            let left = 
replace_literals(*b.left, replacements.clone());\n            let right = replace_literals(*b.right, replacements);\n            Expr::BinaryExpr(BinaryExpr {\n                left: Box::new(left),\n                op: b.op,\n                right: Box::new(right),\n            })\n        }\n        _ => expr,\n    }\n}\n\n// Helper function for datetime_coercion rule, which, given a BinaryExpr expr and a HashMap in\n// which the key represents a BinaryExpr and the value represents a BinaryExpr to replace the key\n// with, returns the modified expr\nfn replace_binary_exprs(expr: Expr, replacements: HashMap<Expr, Expr>) -> Expr {\n    match expr {\n        Expr::BinaryExpr(b) => {\n            if let Some(new_expr) = replacements.get(&Expr::BinaryExpr(b.clone())) {\n                new_expr.clone()\n            } else {\n                let left = replace_binary_exprs(*b.left, replacements.clone());\n                let right = replace_binary_exprs(*b.right, replacements);\n                Expr::BinaryExpr(BinaryExpr {\n                    left: Box::new(left),\n                    op: b.op,\n                    right: Box::new(right),\n                })\n            }\n        }\n        _ => expr,\n    }\n}\n\n// Preoptimization rule which detects when the user is trying to perform a binary operation on a\n// datetime and an integer, then converts the integer to an IntervalMonthDayNano. 
For example, if we\n// have a date_col + 5, we assume that we are adding 5 days to the date_col\npub fn datetime_coercion(plan: &LogicalPlan) -> Option<LogicalPlan> {\n    match plan {\n        LogicalPlan::Filter(f) => {\n            let filter_expr = f.predicate.clone();\n            let columns_and_literals = extract_columns_and_literals(&filter_expr);\n\n            let mut days_to_nanoseconds: Vec<(Expr, HashMap<Expr, Expr>)> = Vec::new();\n            for vector in columns_and_literals.iter() {\n                // Detect whether a timestamp is involved in the operation\n                let mut is_timestamp_operation = false;\n                for item in vector.0.iter() {\n                    if let Expr::Column(column) = item {\n                        if let Some(DataType::Timestamp(TimeUnit::Nanosecond, _)) =\n                            find_data_type(column.clone(), plan.schema().fields().clone())\n                        {\n                            is_timestamp_operation = true;\n                        }\n                    }\n                }\n\n                // Convert an integer to an IntervalMonthDayNano\n                if is_timestamp_operation {\n                    let mut find_replace = HashMap::new();\n                    for item in vector.0.iter() {\n                        if let Expr::Literal(ScalarValue::Int64(i)) = item {\n                            // The day count sits above the 64-bit nanosecond field of the packed\n                            // i128, so scale by exactly 2^64 rather than a rounded approximation\n                            let ns = i.unwrap() as i128 * (1_i128 << 64);\n\n                            find_replace.insert(\n                                Expr::Literal(ScalarValue::Int64(*i)),\n                                Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(ns))),\n                            );\n                        }\n                    }\n                    days_to_nanoseconds.push((vector.1.clone(), find_replace));\n                }\n            }\n\n            let mut binary_exprs = HashMap::new();\n            for replacements in days_to_nanoseconds.iter() {\n         
       binary_exprs.insert(\n                    replacements.0.clone(),\n                    replace_literals(replacements.0.clone(), replacements.1.clone()),\n                );\n            }\n            let new_filter = replace_binary_exprs(filter_expr, binary_exprs);\n            Some(LogicalPlan::Filter(\n                Filter::try_new(new_filter, f.input.clone()).unwrap(),\n            ))\n        }\n        _ => optimize_children(plan.clone()),\n    }\n}\n\n// Function used to iterate through a LogicalPlan and update it accordingly\nfn optimize_children(existing_plan: LogicalPlan) -> Option<LogicalPlan> {\n    let plan = existing_plan.clone();\n    let new_exprs = plan.expressions();\n    let mut new_inputs = Vec::with_capacity(plan.inputs().len());\n    let mut plan_is_changed = false;\n    for input in plan.inputs() {\n        // Since datetime_coercion is the only preoptimizer rule that we have at the moment, we\n        // hardcode it here. If additional preoptimizer rules are added in the future, this can be\n        // modified\n        let new_input = datetime_coercion(input);\n        plan_is_changed = plan_is_changed || new_input.is_some();\n        new_inputs.push(new_input.unwrap_or_else(|| input.clone()))\n    }\n    if plan_is_changed {\n        Some(plan.with_new_exprs(new_exprs, &new_inputs).ok()?)\n    } else {\n        Some(existing_plan)\n    }\n}\n"
  },
  {
    "path": "src/sql/schema.rs",
    "content": "use std::collections::HashMap;\n\nuse ::std::sync::{Arc, Mutex};\nuse pyo3::prelude::*;\n\nuse super::types::PyDataType;\nuse crate::sql::{function::DaskFunction, table};\n\n#[pyclass(name = \"DaskSchema\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone)]\npub struct DaskSchema {\n    #[pyo3(get, set)]\n    pub(crate) name: String,\n    pub(crate) tables: HashMap<String, table::DaskTable>,\n    pub(crate) functions: HashMap<String, Arc<Mutex<DaskFunction>>>,\n}\n\n#[pymethods]\nimpl DaskSchema {\n    #[new]\n    pub fn new(schema_name: &str) -> Self {\n        Self {\n            name: schema_name.to_owned(),\n            tables: HashMap::new(),\n            functions: HashMap::new(),\n        }\n    }\n\n    pub fn add_table(&mut self, table: table::DaskTable) {\n        self.tables.insert(table.table_name.clone(), table);\n    }\n\n    pub fn add_or_overload_function(\n        &mut self,\n        name: String,\n        input_types: Vec<PyDataType>,\n        return_type: PyDataType,\n        aggregation: bool,\n    ) {\n        self.functions\n            .entry(name.clone())\n            .and_modify(|e| {\n                (*e).lock()\n                    .unwrap()\n                    .add_type_mapping(input_types.clone(), return_type.clone());\n            })\n            .or_insert_with(|| {\n                Arc::new(Mutex::new(DaskFunction::new(\n                    name,\n                    input_types,\n                    return_type,\n                    aggregation,\n                )))\n            });\n    }\n}\n"
  },
  {
    "path": "src/sql/statement.rs",
    "content": "use pyo3::prelude::*;\n\nuse crate::parser::DaskStatement;\n\n#[pyclass(name = \"Statement\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone)]\npub struct PyStatement {\n    pub statement: DaskStatement,\n}\n\nimpl From<PyStatement> for DaskStatement {\n    fn from(statement: PyStatement) -> DaskStatement {\n        statement.statement\n    }\n}\n\nimpl From<DaskStatement> for PyStatement {\n    fn from(statement: DaskStatement) -> PyStatement {\n        PyStatement { statement }\n    }\n}\n\nimpl PyStatement {\n    pub fn new(statement: DaskStatement) -> Self {\n        Self { statement }\n    }\n}\n"
  },
  {
    "path": "src/sql/table.rs",
    "content": "use std::{any::Any, sync::Arc};\n\nuse async_trait::async_trait;\nuse datafusion_python::{\n    datafusion::arrow::datatypes::{DataType, Fields, SchemaRef},\n    datafusion_common::DFField,\n    datafusion_expr::{Expr, LogicalPlan, TableProviderFilterPushDown, TableSource},\n    datafusion_optimizer::utils::split_conjunction,\n    datafusion_sql::TableReference,\n};\nuse pyo3::prelude::*;\n\nuse super::logical::{create_table::CreateTablePlanNode, predict_model::PredictModelPlanNode};\nuse crate::{\n    error::DaskPlannerError,\n    sql::{\n        logical,\n        types::{\n            rel_data_type::RelDataType,\n            rel_data_type_field::RelDataTypeField,\n            DaskTypeMap,\n            SqlTypeName,\n        },\n    },\n};\n\n/// DaskTable wrapper that is compatible with DataFusion logical query plans\npub struct DaskTableSource {\n    schema: SchemaRef,\n    statistics: Option<DaskStatistics>,\n    filepath: Option<String>,\n}\n\nimpl DaskTableSource {\n    /// Initialize a new `EmptyTable` from a schema\n    pub fn new(\n        schema: SchemaRef,\n        statistics: Option<DaskStatistics>,\n        filepath: Option<String>,\n    ) -> Self {\n        Self {\n            schema,\n            statistics,\n            filepath,\n        }\n    }\n\n    /// Access optional statistics associated with this table source\n    pub fn statistics(&self) -> Option<&DaskStatistics> {\n        self.statistics.as_ref()\n    }\n\n    /// Access optional filepath associated with this table source\n    pub fn filepath(&self) -> Option<&String> {\n        self.filepath.as_ref()\n    }\n}\n\n/// Implement TableSource, used in the logical query plan and in logical query optimizations\n#[async_trait]\nimpl TableSource for DaskTableSource {\n    fn as_any(&self) -> &dyn Any {\n        self\n    }\n\n    fn schema(&self) -> SchemaRef {\n        self.schema.clone()\n    }\n\n    fn supports_filter_pushdown(\n        &self,\n        filter: &Expr,\n    ) 
-> datafusion_python::datafusion_common::Result<TableProviderFilterPushDown> {\n        let filters = split_conjunction(filter);\n        if filters.iter().all(|f| is_supported_push_down_expr(f)) {\n            // Push down filters to the tablescan operation if all are supported\n            Ok(TableProviderFilterPushDown::Exact)\n        } else if filters.iter().any(|f| is_supported_push_down_expr(f)) {\n            // Partially apply the filter in the TableScan but retain\n            // the Filter operator in the plan as well\n            Ok(TableProviderFilterPushDown::Inexact)\n        } else {\n            Ok(TableProviderFilterPushDown::Unsupported)\n        }\n    }\n}\n\nfn is_supported_push_down_expr(_expr: &Expr) -> bool {\n    // For now we support all kinds of expr's at this level\n    true\n}\n\n#[pyclass(name = \"DaskStatistics\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone)]\npub struct DaskStatistics {\n    row_count: f64,\n}\n\n#[pymethods]\nimpl DaskStatistics {\n    #[new]\n    pub fn new(row_count: f64) -> Self {\n        Self { row_count }\n    }\n\n    #[pyo3(name = \"getRowCount\")]\n    pub fn get_row_count(&self) -> f64 {\n        self.row_count\n    }\n}\n\n#[pyclass(name = \"DaskTable\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone)]\npub struct DaskTable {\n    pub(crate) schema_name: Option<String>,\n    pub(crate) table_name: String,\n    pub(crate) statistics: DaskStatistics,\n    pub(crate) columns: Vec<(String, DaskTypeMap)>,\n    pub(crate) filepath: Option<String>,\n}\n\n#[pymethods]\nimpl DaskTable {\n    #[new]\n    pub fn new(\n        schema_name: &str,\n        table_name: &str,\n        row_count: f64,\n        columns: Option<Vec<(String, DaskTypeMap)>>,\n        filepath: Option<String>,\n    ) -> Self {\n        Self {\n            schema_name: Some(schema_name.to_owned()),\n            table_name: table_name.to_owned(),\n            statistics: DaskStatistics::new(row_count),\n            
columns: columns.unwrap_or_default(),\n            filepath,\n        }\n    }\n\n    // TODO: Really wish we could accept a SqlTypeName instance here instead of a String for `column_type` ....\n    #[pyo3(name = \"add_column\")]\n    pub fn add_column(&mut self, column_name: &str, type_map: DaskTypeMap) {\n        self.columns.push((column_name.to_owned(), type_map));\n    }\n\n    #[pyo3(name = \"getSchema\")]\n    pub fn get_schema(&self) -> PyResult<Option<String>> {\n        Ok(self.schema_name.clone())\n    }\n\n    #[pyo3(name = \"getTableName\")]\n    pub fn get_table_name(&self) -> PyResult<String> {\n        Ok(self.table_name.clone())\n    }\n\n    #[pyo3(name = \"getQualifiedName\")]\n    pub fn qualified_name(&self, plan: logical::PyLogicalPlan) -> Vec<String> {\n        let mut qualified_name = match &self.schema_name {\n            Some(schema_name) => vec![schema_name.clone()],\n            None => vec![],\n        };\n\n        match plan.original_plan {\n            LogicalPlan::TableScan(table_scan) => {\n                qualified_name.push(table_scan.table_name.to_string());\n            }\n            _ => {\n                qualified_name.push(self.table_name.clone());\n            }\n        }\n\n        qualified_name\n    }\n\n    #[pyo3(name = \"getRowType\")]\n    pub fn row_type(&self) -> RelDataType {\n        let mut fields: Vec<RelDataTypeField> = Vec::new();\n        for (name, data_type) in &self.columns {\n            fields.push(RelDataTypeField::new(name.as_str(), data_type.clone(), 255));\n        }\n        RelDataType::new(false, fields)\n    }\n}\n\n/// Traverses the logical plan to locate the Table associated with the query\npub(crate) fn table_from_logical_plan(\n    plan: &LogicalPlan,\n) -> Result<Option<DaskTable>, DaskPlannerError> {\n    match plan {\n        LogicalPlan::Projection(projection) => table_from_logical_plan(&projection.input),\n        LogicalPlan::Filter(filter) => 
table_from_logical_plan(&filter.input),\n        LogicalPlan::TableScan(table_scan) => {\n            // Get the TableProvider for this Table instance\n            let tbl_provider: Arc<dyn TableSource> = table_scan.source.clone();\n            let tbl_schema: SchemaRef = tbl_provider.schema();\n            let fields: &Fields = tbl_schema.fields();\n\n            let mut cols: Vec<(String, DaskTypeMap)> = Vec::new();\n            for field in fields {\n                let data_type: &DataType = field.data_type();\n                cols.push((\n                    String::from(field.name()),\n                    DaskTypeMap::from(\n                        SqlTypeName::from_arrow(data_type)?,\n                        data_type.clone().into(),\n                    ),\n                ));\n            }\n\n            let table_ref: TableReference = table_scan.table_name.clone();\n            let (schema, tbl) = match table_ref {\n                TableReference::Bare { table } => (\"\".to_string(), table),\n                TableReference::Partial { schema, table } => (schema.to_string(), table),\n                TableReference::Full {\n                    catalog: _,\n                    schema,\n                    table,\n                } => (schema.to_string(), table),\n            };\n\n            Ok(Some(DaskTable {\n                schema_name: Some(schema),\n                table_name: String::from(tbl),\n                statistics: DaskStatistics { row_count: 0.0 },\n                columns: cols,\n                filepath: None,\n            }))\n        }\n        LogicalPlan::Join(join) => {\n            // TODO: Don't always hardcode the left\n            table_from_logical_plan(&join.left)\n        }\n        LogicalPlan::Aggregate(agg) => table_from_logical_plan(&agg.input),\n        LogicalPlan::SubqueryAlias(alias) => table_from_logical_plan(&alias.input),\n        LogicalPlan::EmptyRelation(empty_relation) => {\n            let fields: &Vec<DFField> 
= empty_relation.schema.fields();\n\n            let mut cols: Vec<(String, DaskTypeMap)> = Vec::new();\n            for field in fields {\n                let data_type: &DataType = field.data_type();\n                cols.push((\n                    String::from(field.name()),\n                    DaskTypeMap::from(\n                        SqlTypeName::from_arrow(data_type)?,\n                        data_type.clone().into(),\n                    ),\n                ));\n            }\n\n            Ok(Some(DaskTable {\n                schema_name: Some(String::from(\"EmptySchema\")),\n                table_name: String::from(\"EmptyRelation\"),\n                statistics: DaskStatistics { row_count: 0.0 },\n                columns: cols,\n                filepath: None,\n            }))\n        }\n        LogicalPlan::Extension(ex) => {\n            let node = ex.node.as_any();\n            if let Some(e) = node.downcast_ref::<CreateTablePlanNode>() {\n                Ok(Some(DaskTable {\n                    schema_name: e.schema_name.clone(),\n                    table_name: e.table_name.clone(),\n                    statistics: DaskStatistics { row_count: 0.0 },\n                    columns: vec![],\n                    filepath: None,\n                }))\n            } else if let Some(e) = node.downcast_ref::<PredictModelPlanNode>() {\n                Ok(Some(DaskTable {\n                    schema_name: e.schema_name.clone(),\n                    table_name: e.model_name.clone(),\n                    statistics: DaskStatistics { row_count: 0.0 },\n                    columns: vec![],\n                    filepath: None,\n                }))\n            } else {\n                Err(DaskPlannerError::Internal(format!(\n                    \"table_from_logical_plan: unimplemented LogicalPlan type {plan:?} encountered\"\n                )))\n            }\n        }\n        _ => Err(DaskPlannerError::Internal(format!(\n            
\"table_from_logical_plan: unimplemented LogicalPlan type {plan:?} encountered\"\n        ))),\n    }\n}\n"
  },
  {
    "path": "src/sql/types/rel_data_type.rs",
    "content": "use std::collections::HashMap;\n\nuse pyo3::prelude::*;\n\nuse crate::sql::{exceptions::py_runtime_err, types::rel_data_type_field::RelDataTypeField};\n\nconst PRECISION_NOT_SPECIFIED: i32 = i32::MIN;\nconst SCALE_NOT_SPECIFIED: i32 = -1;\n\n/// RelDataType represents the type of a scalar expression or entire row returned from a relational expression.\n#[pyclass(name = \"RelDataType\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]\npub struct RelDataType {\n    nullable: bool,\n    field_list: Vec<RelDataTypeField>,\n}\n\n/// RelDataType represents the type of a scalar expression or entire row returned from a relational expression.\n#[pymethods]\nimpl RelDataType {\n    #[new]\n    pub fn new(nullable: bool, fields: Vec<RelDataTypeField>) -> Self {\n        Self {\n            nullable,\n            field_list: fields,\n        }\n    }\n\n    /// Looks up a field by name.\n    ///\n    /// # Arguments\n    ///\n    /// * `field_name` - A String containing the name of the field to find\n    /// * `case_sensitive` - True if column name matching should be case sensitive and false otherwise\n    #[pyo3(name = \"getField\")]\n    pub fn field(&self, field_name: &str, case_sensitive: bool) -> PyResult<RelDataTypeField> {\n        let field_map: HashMap<String, RelDataTypeField> = self.field_map();\n        if case_sensitive && !field_map.is_empty() {\n            Ok(field_map.get(field_name).unwrap().clone())\n        } else {\n            for field in &self.field_list {\n                if (case_sensitive && field.name().eq(field_name))\n                    || (!case_sensitive && field.name().eq_ignore_ascii_case(field_name))\n                {\n                    return Ok(field.clone());\n                }\n            }\n\n            // TODO: Throw a proper error here\n            Err(py_runtime_err(format!(\n                \"Unable to find RelDataTypeField with name {field_name:?} in the 
RelDataType field_list\"\n            )))\n        }\n    }\n\n    /// Returns a map from field names to fields.\n    ///\n    /// # Notes\n    ///\n    /// * If several fields have the same name, the map contains the first.\n    #[pyo3(name = \"getFieldMap\")]\n    pub fn field_map(&self) -> HashMap<String, RelDataTypeField> {\n        let mut fields: HashMap<String, RelDataTypeField> = HashMap::new();\n        for field in &self.field_list {\n            fields.insert(String::from(field.name()), field.clone());\n        }\n        fields\n    }\n\n    /// Gets the fields in a struct type. The field count is equal to the size of the returned list.\n    #[pyo3(name = \"getFieldList\")]\n    pub fn field_list(&self) -> Vec<RelDataTypeField> {\n        self.field_list.clone()\n    }\n\n    /// Returns the names of all of the columns in a given DaskTable\n    #[pyo3(name = \"getFieldNames\")]\n    pub fn field_names(&self) -> Vec<String> {\n        let mut field_names: Vec<String> = Vec::new();\n        for field in &self.field_list {\n            field_names.push(field.qualified_name());\n        }\n        field_names\n    }\n\n    /// Returns the number of fields in a struct type.\n    #[pyo3(name = \"getFieldCount\")]\n    pub fn field_count(&self) -> usize {\n        self.field_list.len()\n    }\n\n    #[pyo3(name = \"isStruct\")]\n    pub fn is_struct(&self) -> bool {\n        !self.field_list.is_empty()\n    }\n\n    /// Queries whether this type allows null values.\n    #[pyo3(name = \"isNullable\")]\n    pub fn is_nullable(&self) -> bool {\n        self.nullable\n    }\n\n    #[pyo3(name = \"getPrecision\")]\n    pub fn precision(&self) -> i32 {\n        PRECISION_NOT_SPECIFIED\n    }\n\n    #[pyo3(name = \"getScale\")]\n    pub fn scale(&self) -> i32 {\n        SCALE_NOT_SPECIFIED\n    }\n}\n"
  },
  {
    "path": "src/sql/types/rel_data_type_field.rs",
    "content": "use std::fmt;\n\nuse datafusion_python::{\n    datafusion_common::{DFField, DFSchema},\n    datafusion_sql::TableReference,\n};\nuse pyo3::prelude::*;\n\nuse crate::{\n    error::Result,\n    sql::types::{DaskTypeMap, SqlTypeName},\n};\n\n/// RelDataTypeField represents the definition of a field in a structured RelDataType.\n#[pyclass(name = \"RelDataTypeField\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]\npub struct RelDataTypeField {\n    qualifier: Option<String>,\n    name: String,\n    data_type: DaskTypeMap,\n    index: usize,\n}\n\n// Functions that should not be presented to Python are placed here\nimpl RelDataTypeField {\n    pub fn from(field: &DFField, schema: &DFSchema) -> Result<RelDataTypeField> {\n        let qualifier: Option<&TableReference> = field.qualifier();\n        Ok(RelDataTypeField {\n            qualifier: qualifier.map(|qualifier| qualifier.to_string()),\n            name: field.name().clone(),\n            data_type: DaskTypeMap {\n                sql_type: SqlTypeName::from_arrow(field.data_type())?,\n                data_type: field.data_type().clone().into(),\n            },\n            index: schema\n                .index_of_column_by_name(qualifier, field.name())?\n                .unwrap(),\n        })\n    }\n}\n\n#[pymethods]\nimpl RelDataTypeField {\n    #[new]\n    pub fn new(name: &str, type_map: DaskTypeMap, index: usize) -> Self {\n        Self {\n            qualifier: None,\n            name: name.to_owned(),\n            data_type: type_map,\n            index,\n        }\n    }\n\n    #[pyo3(name = \"getQualifier\")]\n    pub fn qualifier(&self) -> Option<String> {\n        self.qualifier.clone()\n    }\n\n    #[pyo3(name = \"getName\")]\n    pub fn name(&self) -> &str {\n        &self.name\n    }\n\n    #[pyo3(name = \"getQualifiedName\")]\n    pub fn qualified_name(&self) -> String {\n        match &self.qualifier() {\n            Some(qualifier) 
=> format!(\"{}.{}\", &qualifier, self.name()),\n            None => self.name().to_string(),\n        }\n    }\n\n    #[pyo3(name = \"getIndex\")]\n    pub fn index(&self) -> usize {\n        self.index\n    }\n\n    #[pyo3(name = \"getType\")]\n    pub fn data_type(&self) -> DaskTypeMap {\n        self.data_type.clone()\n    }\n\n    /// Since this logic is being ported from Java getKey is synonymous with getName.\n    /// Alas it is used in certain places so it is implemented here to allow other\n    /// places in the code base to not have to change.\n    #[pyo3(name = \"getKey\")]\n    pub fn get_key(&self) -> &str {\n        self.name()\n    }\n\n    /// Since this logic is being ported from Java getValue is synonymous with getType.\n    /// Alas it is used in certain places so it is implemented here to allow other\n    /// places in the code base to not have to change.\n    #[pyo3(name = \"getValue\")]\n    pub fn get_value(&self) -> DaskTypeMap {\n        self.data_type()\n    }\n\n    #[pyo3(name = \"setValue\")]\n    pub fn set_value(&mut self, data_type: DaskTypeMap) {\n        self.data_type = data_type\n    }\n\n    // TODO: Uncomment after implementing in RelDataType\n    // #[pyo3(name = \"isDynamicStar\")]\n    // pub fn is_dynamic_star(&self) -> bool {\n    //     self.data_type.getSqlTypeName() == SqlTypeName.DYNAMIC_STAR\n    // }\n}\n\nimpl fmt::Display for RelDataTypeField {\n    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {\n        fmt.write_str(\"Field: \")?;\n        fmt.write_str(&self.name)?;\n        fmt.write_str(\" - Index: \")?;\n        fmt.write_str(&self.index.to_string())?;\n        // TODO: Uncomment this after implementing the Display trait in RelDataType\n        // fmt.write_str(\" - DataType: \")?;\n        // fmt.write_str(self.data_type.to_string())?;\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "src/sql/types.rs",
    "content": "pub mod rel_data_type;\npub mod rel_data_type_field;\n\nuse std::sync::Arc;\n\nuse datafusion_python::{\n    datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit},\n    datafusion_sql::sqlparser::{ast::DataType as SQLType, parser::Parser, tokenizer::Tokenizer},\n};\nuse pyo3::{prelude::*, types::PyDict};\n\nuse crate::{dialect::DaskDialect, error::DaskPlannerError, sql::exceptions::py_type_err};\n\n#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]\n#[pyclass(name = \"RexType\", module = \"dask_sql\")]\npub enum RexType {\n    Alias,\n    Literal,\n    Call,\n    Reference,\n    ScalarSubquery,\n    Other,\n}\n\n#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]\n#[pyclass(name = \"DaskTypeMap\", module = \"dask_sql\", subclass)]\n/// Represents a Python Data Type. This is needed instead of simple\n/// Enum instances because PyO3 can only support unit variants as\n/// of version 0.16 which means Enums like `DataType::TIMESTAMP_WITH_LOCAL_TIME_ZONE`\n/// which generally hold `unit` and `tz` information are unable to\n/// do that so data is lost. 
This struct aims to solve that issue\n/// by taking the type Enum from Python and some optional extra\n/// parameters that can be used to properly create those DataType\n/// instances in Rust.\npub struct DaskTypeMap {\n    sql_type: SqlTypeName,\n    data_type: PyDataType,\n}\n\n/// Functions not exposed to Python\nimpl DaskTypeMap {\n    pub fn from(sql_type: SqlTypeName, data_type: PyDataType) -> Self {\n        DaskTypeMap {\n            sql_type,\n            data_type,\n        }\n    }\n}\n\n#[pymethods]\nimpl DaskTypeMap {\n    #[new]\n    #[pyo3(signature = (sql_type, **py_kwargs))]\n    fn new(sql_type: SqlTypeName, py_kwargs: Option<&PyDict>) -> PyResult<Self> {\n        let d_type: DataType = match sql_type {\n            SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE => {\n                let (unit, tz) = match py_kwargs {\n                    Some(dict) => {\n                        let tz: Option<Arc<str>> = match dict.get_item(\"tz\") {\n                            Some(e) => {\n                                let res: PyResult<String> = e.extract();\n                                Some(Arc::from(<std::string::String as AsRef<str>>::as_ref(\n                                    &res.unwrap(),\n                                )))\n                            }\n                            None => None,\n                        };\n                        let unit: TimeUnit = match dict.get_item(\"unit\") {\n                            Some(e) => {\n                                let res: PyResult<&str> = e.extract();\n                                match res.unwrap() {\n                                    \"Second\" => TimeUnit::Second,\n                                    \"Millisecond\" => TimeUnit::Millisecond,\n                                    \"Microsecond\" => TimeUnit::Microsecond,\n                                    \"Nanosecond\" => TimeUnit::Nanosecond,\n                                    _ => TimeUnit::Nanosecond,\n                     
           }\n                            }\n                            // Default to Nanosecond which is common if not present\n                            None => TimeUnit::Nanosecond,\n                        };\n                        (unit, tz)\n                    }\n                    // Default to Nanosecond and None for tz which is common if not present\n                    None => (TimeUnit::Nanosecond, None),\n                };\n                DataType::Timestamp(unit, tz)\n            }\n            SqlTypeName::TIMESTAMP => {\n                let (unit, tz) = match py_kwargs {\n                    Some(dict) => {\n                        let tz: Option<Arc<str>> = match dict.get_item(\"tz\") {\n                            Some(e) => {\n                                let res: PyResult<String> = e.extract();\n                                Some(Arc::from(<std::string::String as AsRef<str>>::as_ref(\n                                    &res.unwrap(),\n                                )))\n                            }\n                            None => None,\n                        };\n                        let unit: TimeUnit = match dict.get_item(\"unit\") {\n                            Some(e) => {\n                                let res: PyResult<&str> = e.extract();\n                                match res.unwrap() {\n                                    \"Second\" => TimeUnit::Second,\n                                    \"Millisecond\" => TimeUnit::Millisecond,\n                                    \"Microsecond\" => TimeUnit::Microsecond,\n                                    \"Nanosecond\" => TimeUnit::Nanosecond,\n                                    _ => TimeUnit::Nanosecond,\n                                }\n                            }\n                            // Default to Nanosecond which is common if not present\n                            None => TimeUnit::Nanosecond,\n                        };\n                        
(unit, tz)\n                    }\n                    // Default to Nanosecond and None for tz which is common if not present\n                    None => (TimeUnit::Nanosecond, None),\n                };\n                DataType::Timestamp(unit, tz)\n            }\n            SqlTypeName::DECIMAL => {\n                let (precision, scale) = match py_kwargs {\n                    Some(dict) => {\n                        let precision: u8 = match dict.get_item(\"precision\") {\n                            Some(e) => {\n                                let res: PyResult<u8> = e.extract();\n                                res.unwrap()\n                            }\n                            None => 38,\n                        };\n                        let scale: i8 = match dict.get_item(\"scale\") {\n                            Some(e) => {\n                                let res: PyResult<i8> = e.extract();\n                                res.unwrap()\n                            }\n                            None => 0,\n                        };\n                        (precision, scale)\n                    }\n                    None => (38, 10),\n                };\n                DataType::Decimal128(precision, scale)\n            }\n            _ => sql_type.to_arrow()?,\n        };\n\n        Ok(DaskTypeMap {\n            sql_type,\n            data_type: d_type.into(),\n        })\n    }\n\n    fn __str__(&self) -> String {\n        format!(\"{:?}\", self.sql_type)\n    }\n\n    #[pyo3(name = \"getSqlType\")]\n    pub fn sql_type(&self) -> SqlTypeName {\n        self.sql_type.clone()\n    }\n\n    #[pyo3(name = \"getDataType\")]\n    pub fn data_type(&self) -> PyDataType {\n        self.data_type.clone()\n    }\n}\n\n#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]\n#[pyclass(name = \"PyDataType\", module = \"dask_sql\", subclass)]\npub struct PyDataType {\n    data_type: DataType,\n}\n\n#[pymethods]\nimpl PyDataType {\n    /// 
Gets the precision/scale represented by the PyDataType's decimal datatype\n    #[pyo3(name = \"getPrecisionScale\")]\n    pub fn get_precision_scale(&self) -> PyResult<(u8, i8)> {\n        Ok(match &self.data_type {\n            DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => {\n                (*precision, *scale)\n            }\n            _ => {\n                return Err(py_type_err(format!(\n                    \"Catch all triggered in get_precision_scale, {:?}\",\n                    &self.data_type\n                )))\n            }\n        })\n    }\n}\n\nimpl From<PyDataType> for DataType {\n    fn from(data_type: PyDataType) -> DataType {\n        data_type.data_type\n    }\n}\n\nimpl From<DataType> for PyDataType {\n    fn from(data_type: DataType) -> PyDataType {\n        PyDataType { data_type }\n    }\n}\n\n/// Enumeration of the type names which can be used to construct a SQL type. Since\n/// several SQL types do not exist as Rust types and also because the Enum\n/// `SqlTypeName` is already used in the Python Dask-SQL code base this enum is used\n/// in place of just using the built-in Rust types.\n#[allow(non_camel_case_types)]\n#[allow(clippy::upper_case_acronyms)]\n#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]\n#[pyclass(name = \"SqlTypeName\", module = \"dask_sql\")]\npub enum SqlTypeName {\n    ANY,\n    ARRAY,\n    BIGINT,\n    BINARY,\n    BOOLEAN,\n    CHAR,\n    COLUMN_LIST,\n    CURSOR,\n    DATE,\n    DECIMAL,\n    DISTINCT,\n    DOUBLE,\n    DYNAMIC_STAR,\n    FLOAT,\n    GEOMETRY,\n    INTEGER,\n    INTERVAL,\n    INTERVAL_DAY,\n    INTERVAL_DAY_HOUR,\n    INTERVAL_DAY_MINUTE,\n    INTERVAL_DAY_SECOND,\n    INTERVAL_HOUR,\n    INTERVAL_HOUR_MINUTE,\n    INTERVAL_HOUR_SECOND,\n    INTERVAL_MINUTE,\n    INTERVAL_MINUTE_SECOND,\n    INTERVAL_MONTH,\n    INTERVAL_MONTH_DAY_NANOSECOND,\n    INTERVAL_SECOND,\n    INTERVAL_YEAR,\n    INTERVAL_YEAR_MONTH,\n    MAP,\n    MULTISET,\n    
NULL,\n    OTHER,\n    REAL,\n    ROW,\n    SARG,\n    SMALLINT,\n    STRUCTURED,\n    SYMBOL,\n    TIME,\n    TIME_WITH_LOCAL_TIME_ZONE,\n    TIMESTAMP,\n    TIMESTAMP_WITH_LOCAL_TIME_ZONE,\n    TINYINT,\n    UNKNOWN,\n    VARBINARY,\n    VARCHAR,\n}\n\nimpl SqlTypeName {\n    pub fn to_arrow(&self) -> Result<DataType, DaskPlannerError> {\n        match self {\n            SqlTypeName::NULL => Ok(DataType::Null),\n            SqlTypeName::BOOLEAN => Ok(DataType::Boolean),\n            SqlTypeName::TINYINT => Ok(DataType::Int8),\n            SqlTypeName::SMALLINT => Ok(DataType::Int16),\n            SqlTypeName::INTEGER => Ok(DataType::Int32),\n            SqlTypeName::BIGINT => Ok(DataType::Int64),\n            SqlTypeName::REAL => Ok(DataType::Float16),\n            SqlTypeName::FLOAT => Ok(DataType::Float32),\n            SqlTypeName::DOUBLE => Ok(DataType::Float64),\n            SqlTypeName::DATE => Ok(DataType::Date64),\n            SqlTypeName::VARCHAR => Ok(DataType::Utf8),\n            _ => Err(DaskPlannerError::Internal(format!(\n                \"Cannot determine Arrow type for Dask SQL type '{self:?}'\"\n            ))),\n        }\n    }\n\n    pub fn from_arrow(arrow_type: &DataType) -> Result<Self, DaskPlannerError> {\n        match arrow_type {\n            DataType::Null => Ok(SqlTypeName::NULL),\n            DataType::Boolean => Ok(SqlTypeName::BOOLEAN),\n            DataType::Int8 => Ok(SqlTypeName::TINYINT),\n            DataType::Int16 => Ok(SqlTypeName::SMALLINT),\n            DataType::Int32 => Ok(SqlTypeName::INTEGER),\n            DataType::Int64 => Ok(SqlTypeName::BIGINT),\n            DataType::UInt8 => Ok(SqlTypeName::TINYINT),\n            DataType::UInt16 => Ok(SqlTypeName::SMALLINT),\n            DataType::UInt32 => Ok(SqlTypeName::INTEGER),\n            DataType::UInt64 => Ok(SqlTypeName::BIGINT),\n            DataType::Float16 => Ok(SqlTypeName::REAL),\n            DataType::Float32 => Ok(SqlTypeName::FLOAT),\n            
DataType::Float64 => Ok(SqlTypeName::DOUBLE),\n            DataType::Time32(_) | DataType::Time64(_) => Ok(SqlTypeName::TIME),\n            DataType::Timestamp(_unit, tz) => match tz {\n                Some(_) => Ok(SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE),\n                None => Ok(SqlTypeName::TIMESTAMP),\n            },\n            DataType::Date32 => Ok(SqlTypeName::DATE),\n            DataType::Date64 => Ok(SqlTypeName::DATE),\n            DataType::Interval(unit) => match unit {\n                IntervalUnit::DayTime => Ok(SqlTypeName::INTERVAL_DAY),\n                IntervalUnit::YearMonth => Ok(SqlTypeName::INTERVAL_YEAR_MONTH),\n                IntervalUnit::MonthDayNano => Ok(SqlTypeName::INTERVAL_MONTH_DAY_NANOSECOND),\n            },\n            DataType::Binary => Ok(SqlTypeName::BINARY),\n            DataType::FixedSizeBinary(_size) => Ok(SqlTypeName::VARBINARY),\n            DataType::Utf8 => Ok(SqlTypeName::CHAR),\n            DataType::LargeUtf8 => Ok(SqlTypeName::VARCHAR),\n            DataType::Struct(_fields) => Ok(SqlTypeName::STRUCTURED),\n            DataType::Decimal128(_precision, _scale) => Ok(SqlTypeName::DECIMAL),\n            DataType::Decimal256(_precision, _scale) => Ok(SqlTypeName::DECIMAL),\n            DataType::Map(_field, _bool) => Ok(SqlTypeName::MAP),\n            _ => Err(DaskPlannerError::Internal(format!(\n                \"Cannot determine Dask SQL type for Arrow type '{arrow_type:?}'\"\n            ))),\n        }\n    }\n}\n\n#[pymethods]\nimpl SqlTypeName {\n    #[pyo3(name = \"fromString\")]\n    #[staticmethod]\n    pub fn py_from_string(input_type: &str) -> PyResult<Self> {\n        SqlTypeName::from_string(input_type).map_err(|e| e.into())\n    }\n}\n\nimpl SqlTypeName {\n    pub fn from_string(input_type: &str) -> Result<Self, DaskPlannerError> {\n        match input_type.to_uppercase().as_ref() {\n            \"ANY\" => Ok(SqlTypeName::ANY),\n            \"ARRAY\" => Ok(SqlTypeName::ARRAY),\n            
\"NULL\" => Ok(SqlTypeName::NULL),\n            \"BOOLEAN\" => Ok(SqlTypeName::BOOLEAN),\n            \"COLUMN_LIST\" => Ok(SqlTypeName::COLUMN_LIST),\n            \"DISTINCT\" => Ok(SqlTypeName::DISTINCT),\n            \"CURSOR\" => Ok(SqlTypeName::CURSOR),\n            \"TINYINT\" => Ok(SqlTypeName::TINYINT),\n            \"SMALLINT\" => Ok(SqlTypeName::SMALLINT),\n            \"INT\" => Ok(SqlTypeName::INTEGER),\n            \"INTEGER\" => Ok(SqlTypeName::INTEGER),\n            \"BIGINT\" => Ok(SqlTypeName::BIGINT),\n            \"REAL\" => Ok(SqlTypeName::REAL),\n            \"FLOAT\" => Ok(SqlTypeName::FLOAT),\n            \"GEOMETRY\" => Ok(SqlTypeName::GEOMETRY),\n            \"DOUBLE\" => Ok(SqlTypeName::DOUBLE),\n            \"TIME\" => Ok(SqlTypeName::TIME),\n            \"TIME_WITH_LOCAL_TIME_ZONE\" => Ok(SqlTypeName::TIME_WITH_LOCAL_TIME_ZONE),\n            \"TIMESTAMP\" => Ok(SqlTypeName::TIMESTAMP),\n            \"TIMESTAMP_WITH_LOCAL_TIME_ZONE\" => Ok(SqlTypeName::TIMESTAMP_WITH_LOCAL_TIME_ZONE),\n            \"DATE\" => Ok(SqlTypeName::DATE),\n            \"INTERVAL\" => Ok(SqlTypeName::INTERVAL),\n            \"INTERVAL_DAY\" => Ok(SqlTypeName::INTERVAL_DAY),\n            \"INTERVAL_DAY_HOUR\" => Ok(SqlTypeName::INTERVAL_DAY_HOUR),\n            \"INTERVAL_DAY_MINUTE\" => Ok(SqlTypeName::INTERVAL_DAY_MINUTE),\n            \"INTERVAL_DAY_SECOND\" => Ok(SqlTypeName::INTERVAL_DAY_SECOND),\n            \"INTERVAL_HOUR\" => Ok(SqlTypeName::INTERVAL_HOUR),\n            \"INTERVAL_HOUR_MINUTE\" => Ok(SqlTypeName::INTERVAL_HOUR_MINUTE),\n            \"INTERVAL_HOUR_SECOND\" => Ok(SqlTypeName::INTERVAL_HOUR_SECOND),\n            \"INTERVAL_MINUTE\" => Ok(SqlTypeName::INTERVAL_MINUTE),\n            \"INTERVAL_MINUTE_SECOND\" => Ok(SqlTypeName::INTERVAL_MINUTE_SECOND),\n            \"INTERVAL_MONTH\" => Ok(SqlTypeName::INTERVAL_MONTH),\n            \"INTERVAL_SECOND\" => Ok(SqlTypeName::INTERVAL_SECOND),\n            \"INTERVAL_YEAR\" => 
Ok(SqlTypeName::INTERVAL_YEAR),\n            \"INTERVAL_YEAR_MONTH\" => Ok(SqlTypeName::INTERVAL_YEAR_MONTH),\n            \"MAP\" => Ok(SqlTypeName::MAP),\n            \"MULTISET\" => Ok(SqlTypeName::MULTISET),\n            \"OTHER\" => Ok(SqlTypeName::OTHER),\n            \"ROW\" => Ok(SqlTypeName::ROW),\n            \"SARG\" => Ok(SqlTypeName::SARG),\n            \"BINARY\" => Ok(SqlTypeName::BINARY),\n            \"VARBINARY\" => Ok(SqlTypeName::VARBINARY),\n            \"CHAR\" => Ok(SqlTypeName::CHAR),\n            \"VARCHAR\" | \"STRING\" => Ok(SqlTypeName::VARCHAR),\n            \"STRUCTURED\" => Ok(SqlTypeName::STRUCTURED),\n            \"SYMBOL\" => Ok(SqlTypeName::SYMBOL),\n            \"DECIMAL\" => Ok(SqlTypeName::DECIMAL),\n            \"DYNAMIC_STAR\" => Ok(SqlTypeName::DYNAMIC_STAR),\n            \"UNKNOWN\" => Ok(SqlTypeName::UNKNOWN),\n            _ => {\n                // complex data type name so use the sqlparser\n                let dialect = DaskDialect {};\n                let mut tokenizer = Tokenizer::new(&dialect, input_type);\n                let tokens = tokenizer.tokenize().map_err(DaskPlannerError::from)?;\n                let mut parser = Parser::new(&dialect).with_tokens(tokens);\n                match parser.parse_data_type().map_err(DaskPlannerError::from)? 
{\n                    SQLType::Decimal(_) => Ok(SqlTypeName::DECIMAL),\n                    SQLType::Binary(_) => Ok(SqlTypeName::BINARY),\n                    SQLType::Varbinary(_) => Ok(SqlTypeName::VARBINARY),\n                    SQLType::Varchar(_) | SQLType::Nvarchar(_) => Ok(SqlTypeName::VARCHAR),\n                    SQLType::Char(_) => Ok(SqlTypeName::CHAR),\n                    _ => Err(DaskPlannerError::Internal(format!(\n                        \"Cannot determine Dask SQL type for '{input_type}'\"\n                    ))),\n                }\n            }\n        }\n    }\n}\n\n#[cfg(test)]\nmod test {\n    use crate::sql::types::SqlTypeName;\n\n    #[test]\n    fn invalid_type_name() {\n        assert_eq!(\n            \"Internal Error: Cannot determine Dask SQL type for 'bob'\",\n            SqlTypeName::from_string(\"bob\")\n                .expect_err(\"invalid type name\")\n                .to_string()\n        );\n    }\n\n    #[test]\n    fn string() {\n        assert_expected(\"VARCHAR\", \"string\");\n    }\n\n    #[test]\n    fn varchar_n() {\n        assert_expected(\"VARCHAR\", \"VARCHAR(10)\");\n    }\n\n    #[test]\n    fn decimal_p_s() {\n        assert_expected(\"DECIMAL\", \"DECIMAL(10, 2)\");\n    }\n\n    fn assert_expected(expected: &str, input: &str) {\n        assert_eq!(\n            expected,\n            &format!(\"{:?}\", SqlTypeName::from_string(input).unwrap())\n        );\n    }\n}\n"
  },
  {
    "path": "src/sql.rs",
    "content": "pub mod column;\npub mod exceptions;\npub mod function;\npub mod logical;\npub mod optimizer;\npub mod parser_utils;\npub mod preoptimizer;\npub mod schema;\npub mod statement;\npub mod table;\npub mod types;\n\nuse std::{collections::HashMap, sync::Arc};\n\nuse datafusion_python::{\n    datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit},\n    datafusion_common::{\n        config::ConfigOptions,\n        tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion},\n        DFSchema,\n        DataFusionError,\n    },\n    datafusion_expr::{\n        logical_plan::Extension,\n        AccumulatorFactoryFunction,\n        AggregateUDF,\n        LogicalPlan,\n        ReturnTypeFunction,\n        ScalarFunctionImplementation,\n        ScalarUDF,\n        Signature,\n        StateTypeFunction,\n        TableSource,\n        TypeSignature,\n        Volatility,\n    },\n    datafusion_sql::{\n        parser::Statement as DFStatement,\n        planner::{ContextProvider, SqlToRel},\n        ResolvedTableReference,\n        TableReference,\n    },\n};\nuse log::{debug, warn};\nuse pyo3::prelude::*;\n\nuse self::logical::{\n    create_catalog_schema::CreateCatalogSchemaPlanNode,\n    drop_schema::DropSchemaPlanNode,\n    use_schema::UseSchemaPlanNode,\n};\nuse crate::{\n    dialect::DaskDialect,\n    parser::{DaskParser, DaskStatement},\n    sql::{\n        exceptions::{py_optimization_exp, py_parsing_exp, py_runtime_err},\n        logical::{\n            alter_schema::AlterSchemaPlanNode,\n            alter_table::AlterTablePlanNode,\n            analyze_table::AnalyzeTablePlanNode,\n            create_experiment::CreateExperimentPlanNode,\n            create_model::CreateModelPlanNode,\n            create_table::CreateTablePlanNode,\n            describe_model::DescribeModelPlanNode,\n            drop_model::DropModelPlanNode,\n            export_model::ExportModelPlanNode,\n            predict_model::PredictModelPlanNode,\n            
show_columns::ShowColumnsPlanNode,\n            show_models::ShowModelsPlanNode,\n            show_schemas::ShowSchemasPlanNode,\n            show_tables::ShowTablesPlanNode,\n            PyLogicalPlan,\n        },\n        preoptimizer::datetime_coercion,\n    },\n};\n\n/// DaskSQLContext is the main interface used for interacting with DataFusion to\n/// parse SQL queries, build logical plans, and optimize logical plans.\n#[pyclass(name = \"DaskSQLContext\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone)]\npub struct DaskSQLContext {\n    current_catalog: String,\n    current_schema: String,\n    schemas: HashMap<String, schema::DaskSchema>,\n    options: ConfigOptions,\n    optimizer_config: DaskSQLOptimizerConfig,\n}\n\n#[pyclass(name = \"DaskSQLOptimizerConfig\", module = \"dask_sql\", subclass)]\n#[derive(Debug, Clone)]\npub struct DaskSQLOptimizerConfig {\n    dynamic_partition_pruning: bool,\n    fact_dimension_ratio: Option<f64>,\n    max_fact_tables: Option<usize>,\n    preserve_user_order: Option<bool>,\n    filter_selectivity: Option<f64>,\n}\n\n#[pymethods]\nimpl DaskSQLOptimizerConfig {\n    #[new]\n    pub fn new(\n        dynamic_partition_pruning: bool,\n        fact_dimension_ratio: Option<f64>,\n        max_fact_tables: Option<usize>,\n        preserve_user_order: Option<bool>,\n        filter_selectivity: Option<f64>,\n    ) -> Self {\n        Self {\n            dynamic_partition_pruning,\n            fact_dimension_ratio,\n            max_fact_tables,\n            preserve_user_order,\n            filter_selectivity,\n        }\n    }\n}\n\nimpl ContextProvider for DaskSQLContext {\n    fn get_table_provider(\n        &self,\n        name: TableReference,\n    ) -> Result<Arc<dyn TableSource>, DataFusionError> {\n        let reference: ResolvedTableReference = name\n            .clone()\n            
.resolve(&self.current_catalog, &self.current_schema);\n        if reference.catalog != self.current_catalog {\n            // there is a single catalog in Dask SQL\n            return Err(DataFusionError::Plan(format!(\n                \"Cannot resolve catalog '{}'\",\n                reference.catalog\n            )));\n        }\n        let schema_name = reference.clone().schema.into_owned();\n        match self.schemas.get(&schema_name) {\n            Some(schema) => {\n                let mut resp = None;\n                for table in schema.tables.values() {\n                    if table.table_name.eq(&name.table()) {\n                        // Build the Schema here\n                        let mut fields: Vec<Field> = Vec::new();\n                        // Iterate through the DaskTable instance and create a Schema instance\n                        for (column_name, column_type) in &table.columns {\n                            fields.push(Field::new(\n                                column_name,\n                                DataType::from(column_type.data_type()),\n                                true,\n                            ));\n                        }\n\n                        resp = Some(Schema::new(fields));\n                    }\n                }\n\n                // If the Table is not found return None. 
DataFusion will handle the error propagation\n                match resp {\n                    Some(e) => {\n                        let table_ref = &self\n                            .schemas\n                            .get(reference.schema.as_ref())\n                            .unwrap()\n                            .tables\n                            .get(reference.table.as_ref())\n                            .unwrap();\n                        let statistics = &table_ref.statistics;\n                        let filepath = &table_ref.filepath;\n                        if statistics.get_row_count() == 0.0 {\n                            Ok(Arc::new(table::DaskTableSource::new(\n                                Arc::new(e),\n                                None,\n                                filepath.clone(),\n                            )))\n                        } else {\n                            Ok(Arc::new(table::DaskTableSource::new(\n                                Arc::new(e),\n                                Some(statistics.clone()),\n                                filepath.clone(),\n                            )))\n                        }\n                    }\n                    None => Err(DataFusionError::Plan(format!(\n                        \"Table '{}.{}.{}' not found\",\n                        reference.catalog, reference.schema, reference.table\n                    ))),\n                }\n            }\n            None => Err(DataFusionError::Plan(format!(\n                \"Unable to locate Schema: '{}.{}'\",\n                reference.catalog, reference.schema\n            ))),\n        }\n    }\n\n    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {\n        let fun: ScalarFunctionImplementation =\n            Arc::new(|_| Err(DataFusionError::NotImplemented(\"\".to_string())));\n\n        let numeric_datatypes = vec![\n            DataType::Int8,\n            DataType::Int16,\n            
DataType::Int32,\n            DataType::Int64,\n            DataType::UInt8,\n            DataType::UInt16,\n            DataType::UInt32,\n            DataType::UInt64,\n            DataType::Float16,\n            DataType::Float32,\n            DataType::Float64,\n        ];\n\n        match name {\n            \"year\" => {\n                let sig = Signature::exact(\n                    vec![DataType::Timestamp(TimeUnit::Nanosecond, None)],\n                    Volatility::Immutable,\n                );\n                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));\n                return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));\n            }\n            \"last_day\" => {\n                let sig = Signature::exact(\n                    vec![DataType::Timestamp(TimeUnit::Nanosecond, None)],\n                    Volatility::Immutable,\n                );\n                let rtf: ReturnTypeFunction =\n                    Arc::new(|_| Ok(Arc::new(DataType::Timestamp(TimeUnit::Nanosecond, None))));\n                return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));\n            }\n            \"timestampceil\" | \"timestampfloor\" => {\n                // let sig = Signature::exact(\n                //     vec![DataType::Timestamp(TimeUnit::Nanosecond, None), DataType::Date64, DataType::Utf8],\n                //     Volatility::Immutable,\n                // );\n                let sig = Signature::one_of(\n                    vec![\n                        TypeSignature::Exact(vec![DataType::Date64, DataType::Utf8]),\n                        TypeSignature::Exact(vec![\n                            DataType::Timestamp(TimeUnit::Nanosecond, None),\n                            DataType::Utf8,\n                        ]),\n                    ],\n                    Volatility::Immutable,\n                );\n                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Date64)));\n         
       return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));\n            }\n            \"timestampadd\" => {\n                let sig = Signature::one_of(\n                    vec![\n                        TypeSignature::Exact(vec![\n                            DataType::Utf8,\n                            DataType::Int64,\n                            DataType::Date64,\n                        ]),\n                        TypeSignature::Exact(vec![\n                            DataType::Utf8,\n                            DataType::Int64,\n                            DataType::Timestamp(TimeUnit::Nanosecond, None),\n                        ]),\n                        TypeSignature::Exact(vec![\n                            DataType::Utf8,\n                            DataType::Int64,\n                            DataType::Int64,\n                        ]),\n                    ],\n                    Volatility::Immutable,\n                );\n                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Date64)));\n                return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));\n            }\n            \"timestampdiff\" => {\n                let sig = Signature::one_of(\n                    vec![\n                        TypeSignature::Exact(vec![\n                            DataType::Utf8,\n                            DataType::Timestamp(TimeUnit::Nanosecond, None),\n                            DataType::Timestamp(TimeUnit::Nanosecond, None),\n                        ]),\n                        TypeSignature::Exact(vec![\n                            DataType::Utf8,\n                            DataType::Date64,\n                            DataType::Date64,\n                        ]),\n                        TypeSignature::Exact(vec![\n                            DataType::Utf8,\n                            DataType::Int64,\n                            DataType::Int64,\n                        ]),\n              
      ],\n                    Volatility::Immutable,\n                );\n                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));\n                return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));\n            }\n            \"dsql_totimestamp\" => {\n                let first_datatypes = vec![\n                    DataType::Int8,\n                    DataType::Int16,\n                    DataType::Int32,\n                    DataType::Int64,\n                    DataType::UInt8,\n                    DataType::UInt16,\n                    DataType::UInt32,\n                    DataType::UInt64,\n                    DataType::Utf8,\n                ];\n                let sig = generate_signatures(vec![first_datatypes, vec![DataType::Utf8]]);\n                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Date64)));\n                return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));\n            }\n            \"mod\" => {\n                let sig = generate_signatures(vec![numeric_datatypes.clone(), numeric_datatypes]);\n                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));\n                return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));\n            }\n            \"cbrt\" | \"cot\" | \"degrees\" | \"radians\" | \"sign\" | \"truncate\" => {\n                let sig = generate_signatures(vec![numeric_datatypes]);\n                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));\n                return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));\n            }\n            \"rand\" => {\n                let sig = Signature::one_of(\n                    vec![\n                        TypeSignature::Exact(vec![]),\n                        TypeSignature::Exact(vec![DataType::Int64]),\n                    ],\n                    Volatility::Immutable,\n                );\n                let rtf: 
ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));\n                return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));\n            }\n            \"rand_integer\" => {\n                let sig = Signature::one_of(\n                    vec![\n                        TypeSignature::Exact(vec![DataType::Int64]),\n                        TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]),\n                    ],\n                    Volatility::Immutable,\n                );\n                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));\n                return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));\n            }\n            \"extract_date\" => {\n                let sig = Signature::one_of(\n                    vec![\n                        TypeSignature::Exact(vec![DataType::Utf8, DataType::Date64]),\n                        TypeSignature::Exact(vec![\n                            DataType::Utf8,\n                            DataType::Timestamp(TimeUnit::Nanosecond, None),\n                        ]),\n                    ],\n                    Volatility::Immutable,\n                );\n                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Date64)));\n                return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));\n            }\n            _ => (),\n        }\n\n        // Loop through all of the user defined functions\n        for schema in self.schemas.values() {\n            for (fun_name, func_mutex) in &schema.functions {\n                if fun_name.eq(name) {\n                    let function = func_mutex.lock().unwrap();\n                    if function.aggregation.eq(&true) {\n                        return None;\n                    }\n                    let sig = {\n                        Signature::one_of(\n                            function\n                                .return_types\n                                
.keys()\n                                .map(|v| TypeSignature::Exact(v.to_vec()))\n                                .collect(),\n                            Volatility::Immutable,\n                        )\n                    };\n                    let function = function.clone();\n                    let rtf: ReturnTypeFunction = Arc::new(move |input_types| {\n                        match function.return_types.get(&input_types.to_vec()) {\n                            Some(return_type) => Ok(Arc::new(return_type.clone())),\n                            None => Err(DataFusionError::Plan(format!(\n                                \"UDF signature not found for input types {input_types:?}\"\n                            ))),\n                        }\n                    });\n                    return Some(Arc::new(ScalarUDF::new(\n                        fun_name.as_str(),\n                        &sig,\n                        &rtf,\n                        &fun,\n                    )));\n                }\n            }\n        }\n\n        None\n    }\n\n    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {\n        let acc: AccumulatorFactoryFunction =\n            Arc::new(|_return_type| Err(DataFusionError::NotImplemented(\"\".to_string())));\n\n        let st: StateTypeFunction =\n            Arc::new(|_| Err(DataFusionError::NotImplemented(\"\".to_string())));\n\n        let numeric_datatypes = vec![\n            DataType::Int8,\n            DataType::Int16,\n            DataType::Int32,\n            DataType::Int64,\n            DataType::UInt8,\n            DataType::UInt16,\n            DataType::UInt32,\n            DataType::UInt64,\n            DataType::Float16,\n            DataType::Float32,\n            DataType::Float64,\n        ];\n\n        match name {\n            \"every\" => {\n                // let sig = generate_signatures(vec![DataType::Boolean]);\n                let sig = 
Signature::exact(vec![DataType::Boolean], Volatility::Immutable);\n                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Boolean)));\n                return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));\n            }\n            \"bit_and\" | \"bit_or\" => {\n                let sig = generate_signatures(vec![numeric_datatypes]);\n                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));\n                return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));\n            }\n            \"single_value\" => {\n                let sig = generate_signatures(vec![numeric_datatypes]);\n                let rtf: ReturnTypeFunction =\n                    Arc::new(|input_types| Ok(Arc::new(input_types[0].clone())));\n                return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));\n            }\n            \"regr_count\" => {\n                let sig = generate_signatures(vec![numeric_datatypes.clone(), numeric_datatypes]);\n                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));\n                return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));\n            }\n            \"regr_syy\" | \"regr_sxx\" => {\n                let sig = generate_signatures(vec![numeric_datatypes.clone(), numeric_datatypes]);\n                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));\n                return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));\n            }\n            _ => (),\n        }\n\n        // Loop through all of the user defined functions\n        for schema in self.schemas.values() {\n            for (fun_name, func_mutex) in &schema.functions {\n                if fun_name.eq(name) {\n                    let function = func_mutex.lock().unwrap();\n                    if function.aggregation.eq(&false) {\n                        return None;\n         
           }\n                    let sig = {\n                        Signature::one_of(\n                            function\n                                .return_types\n                                .keys()\n                                .map(|v| TypeSignature::Exact(v.to_vec()))\n                                .collect(),\n                            Volatility::Immutable,\n                        )\n                    };\n                    let function = function.clone();\n                    let rtf: ReturnTypeFunction = Arc::new(move |input_types| {\n                        match function.return_types.get(&input_types.to_vec()) {\n                            Some(return_type) => Ok(Arc::new(return_type.clone())),\n                            None => Err(DataFusionError::Plan(format!(\n                                \"UDAF signature not found for input types {input_types:?}\"\n                            ))),\n                        }\n                    });\n                    return Some(Arc::new(AggregateUDF::new(fun_name, &sig, &rtf, &acc, &st)));\n                }\n            }\n        }\n\n        None\n    }\n\n    fn get_variable_type(&self, _: &[String]) -> Option<DataType> {\n        unimplemented!(\"RUST: get_variable_type is not yet implemented for DaskSQLContext\")\n    }\n\n    fn options(&self) -> &ConfigOptions {\n        &self.options\n    }\n\n    fn get_window_meta(\n        &self,\n        _name: &str,\n    ) -> Option<Arc<datafusion_python::datafusion_expr::WindowUDF>> {\n        unimplemented!(\"RUST: get_window_meta is not yet implemented for DaskSQLContext\")\n    }\n}\n\n#[pymethods]\nimpl DaskSQLContext {\n    #[new]\n    pub fn new(\n        default_catalog_name: &str,\n        default_schema_name: &str,\n        optimizer_config: DaskSQLOptimizerConfig,\n    ) -> Self {\n        Self {\n            current_catalog: default_catalog_name.to_owned(),\n            current_schema: default_schema_name.to_owned(),\n     
       schemas: HashMap::new(),\n            options: ConfigOptions::new(),\n            optimizer_config,\n        }\n    }\n\n    pub fn set_optimizer_config(&mut self, config: DaskSQLOptimizerConfig) -> PyResult<()> {\n        self.optimizer_config = config;\n        Ok(())\n    }\n\n    /// Change the current schema\n    pub fn use_schema(&mut self, schema_name: &str) -> PyResult<()> {\n        if self.schemas.contains_key(schema_name) {\n            self.current_schema = schema_name.to_owned();\n            Ok(())\n        } else {\n            Err(py_runtime_err(format!(\n                \"Schema: {schema_name} not found in DaskSQLContext\"\n            )))\n        }\n    }\n\n    /// Register a Schema with the current DaskSQLContext\n    pub fn register_schema(\n        &mut self,\n        schema_name: String,\n        schema: schema::DaskSchema,\n    ) -> PyResult<bool> {\n        self.schemas.insert(schema_name, schema);\n        Ok(true)\n    }\n\n    /// Register a DaskTable instance under the specified schema in the current DaskSQLContext\n    pub fn register_table(\n        &mut self,\n        schema_name: String,\n        table: table::DaskTable,\n    ) -> PyResult<bool> {\n        match self.schemas.get_mut(&schema_name) {\n            Some(schema) => {\n                schema.add_table(table);\n                Ok(true)\n            }\n            None => Err(py_runtime_err(format!(\n                \"Schema: {schema_name} not found in DaskSQLContext\"\n            ))),\n        }\n    }\n\n    /// Parses a SQL string into an AST presented as a Vec of Statements\n    pub fn parse_sql(&self, sql: &str) -> PyResult<Vec<statement::PyStatement>> {\n        debug!(\"parse_sql - '{}'\", sql);\n        let dd: DaskDialect = DaskDialect {};\n        match DaskParser::parse_sql_with_dialect(sql, &dd) {\n            Ok(k) => {\n                let mut statements: Vec<statement::PyStatement> = Vec::new();\n                for statement in k {\n                 
   statements.push(statement.into());\n                }\n                Ok(statements)\n            }\n            Err(e) => Err(py_parsing_exp(e)),\n        }\n    }\n\n    /// Creates a non-optimized Relational Algebra LogicalPlan from an AST Statement\n    pub fn logical_relational_algebra(\n        &self,\n        statement: statement::PyStatement,\n    ) -> PyResult<logical::PyLogicalPlan> {\n        self._logical_relational_algebra(statement.statement)\n            .map(|e| PyLogicalPlan {\n                original_plan: e,\n                current_node: None,\n            })\n            .map_err(py_parsing_exp)\n    }\n\n    /// Applies datetime coercion to the plan ahead of optimization and\n    /// returns the plan unchanged when no coercion applies\n    pub fn run_preoptimizer(\n        &self,\n        existing_plan: logical::PyLogicalPlan,\n    ) -> PyResult<logical::PyLogicalPlan> {\n        if let Some(plan) = datetime_coercion(&existing_plan.original_plan) {\n            Ok(plan.into())\n        } else {\n            Ok(existing_plan)\n        }\n    }\n\n    /// Accepts an existing relational plan, `LogicalPlan`, and optimizes it\n    /// by applying a set of `optimizer` trait implementations against the\n    /// `LogicalPlan`\n    pub fn optimize_relational_algebra(\n        &self,\n        existing_plan: logical::PyLogicalPlan,\n    ) -> PyResult<logical::PyLogicalPlan> {\n        // Certain queries cannot be optimized, e.g. `EXPLAIN SELECT * FROM test`; return those plans as is\n        let mut visitor = OptimizablePlanVisitor {};\n\n        match existing_plan.original_plan.visit(&mut visitor) {\n            Ok(valid) => {\n                match valid {\n                    VisitRecursion::Stop => {\n                        // This LogicalPlan does not support optimization; return the original\n                        warn!(\"This LogicalPlan does not support Optimization. 
Returning original\");\n                        Ok(existing_plan)\n                    }\n                    _ => {\n                        let optimized_plan = optimizer::DaskSqlOptimizer::new(\n                            self.optimizer_config.fact_dimension_ratio,\n                            self.optimizer_config.max_fact_tables,\n                            self.optimizer_config.preserve_user_order,\n                            self.optimizer_config.filter_selectivity,\n                        )\n                        .optimize(existing_plan.original_plan)\n                        .map(|k| PyLogicalPlan {\n                            original_plan: k,\n                            current_node: None,\n                        })\n                        .map_err(py_optimization_exp);\n                        if let Ok(optimized_plan) = optimized_plan {\n                            if self.optimizer_config.dynamic_partition_pruning {\n                                optimizer::DaskSqlOptimizer::dynamic_partition_pruner(\n                                    self.optimizer_config.fact_dimension_ratio,\n                                )\n                                .optimize_once(optimized_plan.original_plan)\n                                .map(|k| PyLogicalPlan {\n                                    original_plan: k,\n                                    current_node: None,\n                                })\n                                .map_err(py_optimization_exp)\n                            } else {\n                                Ok(optimized_plan)\n                            }\n                        } else {\n                            optimized_plan\n                        }\n                    }\n                }\n            }\n            Err(e) => Err(py_optimization_exp(e)),\n        }\n    }\n}\n\n/// non-Python methods\nimpl DaskSQLContext {\n    /// Creates a non-optimized Relational Algebra LogicalPlan from an AST Statement\n   
 pub fn _logical_relational_algebra(\n        &self,\n        dask_statement: DaskStatement,\n    ) -> Result<LogicalPlan, DataFusionError> {\n        match dask_statement {\n            DaskStatement::Statement(statement) => {\n                let planner = SqlToRel::new(self);\n                planner.statement_to_plan(DFStatement::Statement(statement))\n            }\n            DaskStatement::CreateModel(create_model) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(CreateModelPlanNode {\n                    schema_name: create_model.schema_name,\n                    model_name: create_model.model_name,\n                    input: self._logical_relational_algebra(create_model.select)?,\n                    if_not_exists: create_model.if_not_exists,\n                    or_replace: create_model.or_replace,\n                    with_options: create_model.with_options,\n                }),\n            })),\n            DaskStatement::CreateExperiment(create_experiment) => {\n                Ok(LogicalPlan::Extension(Extension {\n                    node: Arc::new(CreateExperimentPlanNode {\n                        schema_name: create_experiment.schema_name,\n                        experiment_name: create_experiment.experiment_name,\n                        input: self._logical_relational_algebra(create_experiment.select)?,\n                        if_not_exists: create_experiment.if_not_exists,\n                        or_replace: create_experiment.or_replace,\n                        with_options: create_experiment.with_options,\n                    }),\n                }))\n            }\n            DaskStatement::PredictModel(predict_model) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(PredictModelPlanNode {\n                    schema_name: predict_model.schema_name,\n                    model_name: predict_model.model_name,\n                    input: 
self._logical_relational_algebra(predict_model.select)?,\n                }),\n            })),\n            DaskStatement::DescribeModel(describe_model) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(DescribeModelPlanNode {\n                    schema: Arc::new(DFSchema::empty()),\n                    schema_name: describe_model.schema_name,\n                    model_name: describe_model.model_name,\n                }),\n            })),\n            DaskStatement::CreateCatalogSchema(create_schema) => {\n                Ok(LogicalPlan::Extension(Extension {\n                    node: Arc::new(CreateCatalogSchemaPlanNode {\n                        schema: Arc::new(DFSchema::empty()),\n                        schema_name: create_schema.schema_name,\n                        if_not_exists: create_schema.if_not_exists,\n                        or_replace: create_schema.or_replace,\n                    }),\n                }))\n            }\n            DaskStatement::CreateTable(create_table) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(CreateTablePlanNode {\n                    schema: Arc::new(DFSchema::empty()),\n                    schema_name: create_table.schema_name,\n                    table_name: create_table.table_name,\n                    if_not_exists: create_table.if_not_exists,\n                    or_replace: create_table.or_replace,\n                    with_options: create_table.with_options,\n                }),\n            })),\n            DaskStatement::ExportModel(export_model) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(ExportModelPlanNode {\n                    schema: Arc::new(DFSchema::empty()),\n                    schema_name: export_model.schema_name,\n                    model_name: export_model.model_name,\n                    with_options: export_model.with_options,\n                }),\n            })),\n            
DaskStatement::DropModel(drop_model) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(DropModelPlanNode {\n                    schema_name: drop_model.schema_name,\n                    model_name: drop_model.model_name,\n                    if_exists: drop_model.if_exists,\n                    schema: Arc::new(DFSchema::empty()),\n                }),\n            })),\n            DaskStatement::ShowSchemas(show_schemas) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(ShowSchemasPlanNode {\n                    schema: Arc::new(DFSchema::empty()),\n                    catalog_name: show_schemas.catalog_name,\n                    like: show_schemas.like,\n                }),\n            })),\n            DaskStatement::ShowTables(show_tables) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(ShowTablesPlanNode {\n                    schema: Arc::new(DFSchema::empty()),\n                    catalog_name: show_tables.catalog_name,\n                    schema_name: show_tables.schema_name,\n                }),\n            })),\n            DaskStatement::ShowColumns(show_columns) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(ShowColumnsPlanNode {\n                    schema: Arc::new(DFSchema::empty()),\n                    table_name: show_columns.table_name,\n                    schema_name: show_columns.schema_name,\n                }),\n            })),\n            DaskStatement::ShowModels(show_models) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(ShowModelsPlanNode {\n                    schema: Arc::new(DFSchema::empty()),\n                    schema_name: show_models.schema_name,\n                }),\n            })),\n            DaskStatement::DropSchema(drop_schema) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(DropSchemaPlanNode {\n                    schema: Arc::new(DFSchema::empty()),\n                 
   schema_name: drop_schema.schema_name,\n                    if_exists: drop_schema.if_exists,\n                }),\n            })),\n            DaskStatement::UseSchema(use_schema) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(UseSchemaPlanNode {\n                    schema: Arc::new(DFSchema::empty()),\n                    schema_name: use_schema.schema_name,\n                }),\n            })),\n            DaskStatement::AnalyzeTable(analyze_table) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(AnalyzeTablePlanNode {\n                    schema: Arc::new(DFSchema::empty()),\n                    table_name: analyze_table.table_name,\n                    schema_name: analyze_table.schema_name,\n                    columns: analyze_table.columns,\n                }),\n            })),\n            DaskStatement::AlterTable(alter_table) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(AlterTablePlanNode {\n                    schema: Arc::new(DFSchema::empty()),\n                    old_table_name: alter_table.old_table_name,\n                    new_table_name: alter_table.new_table_name,\n                    schema_name: alter_table.schema_name,\n                    if_exists: alter_table.if_exists,\n                }),\n            })),\n            DaskStatement::AlterSchema(alter_schema) => Ok(LogicalPlan::Extension(Extension {\n                node: Arc::new(AlterSchemaPlanNode {\n                    schema: Arc::new(DFSchema::empty()),\n                    old_schema_name: alter_schema.old_schema_name,\n                    new_schema_name: alter_schema.new_schema_name,\n                }),\n            })),\n        }\n    }\n}\n\n/// Visits each AST node to determine if the plan is valid for optimization or not\npub struct OptimizablePlanVisitor;\n\nimpl TreeNodeVisitor for OptimizablePlanVisitor {\n    type N = LogicalPlan;\n\n    fn pre_visit(&mut self, plan: &LogicalPlan) 
-> Result<VisitRecursion, DataFusionError> {\n        // If the plan contains an unsupported Node type we flag the plan as un-optimizable here\n        match plan {\n            LogicalPlan::Explain(..) => Ok(VisitRecursion::Stop),\n            _ => Ok(VisitRecursion::Continue),\n        }\n    }\n\n    fn post_visit(&mut self, _plan: &LogicalPlan) -> Result<VisitRecursion, DataFusionError> {\n        Ok(VisitRecursion::Continue)\n    }\n}\n\nfn generate_signatures(cartesian_setup: Vec<Vec<DataType>>) -> Signature {\n    let mut exact_vector = vec![];\n    let mut datatypes_iter = cartesian_setup.iter();\n    // First pass\n    if let Some(first_iter) = datatypes_iter.next() {\n        for datatype in first_iter {\n            exact_vector.push(vec![datatype.clone()]);\n        }\n    }\n    // Generate the Cartesian product\n    for iter in datatypes_iter {\n        let mut outer_temp = vec![];\n        for outer_datatype in exact_vector {\n            for inner_datatype in iter {\n                let mut inner_temp = outer_datatype.clone();\n                inner_temp.push(inner_datatype.clone());\n                outer_temp.push(inner_temp);\n            }\n        }\n        exact_vector = outer_temp;\n    }\n\n    // Create vector of TypeSignatures\n    let mut one_of_vector = vec![];\n    for vector in exact_vector.iter() {\n        one_of_vector.push(TypeSignature::Exact(vector.clone()));\n    }\n\n    Signature::one_of(one_of_vector.clone(), Volatility::Immutable)\n}\n\n#[cfg(test)]\nmod test {\n    use datafusion_python::{\n        datafusion::arrow::datatypes::DataType,\n        datafusion_expr::{Signature, TypeSignature, Volatility},\n    };\n\n    use crate::sql::generate_signatures;\n\n    #[test]\n    fn test_generate_signatures() {\n        let sig = generate_signatures(vec![\n            vec![DataType::Int64, DataType::Float64],\n            vec![DataType::Utf8, DataType::Int64],\n        ]);\n        let expected = Signature::one_of(\n            
vec![\n                TypeSignature::Exact(vec![DataType::Int64, DataType::Utf8]),\n                TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]),\n                TypeSignature::Exact(vec![DataType::Float64, DataType::Utf8]),\n                TypeSignature::Exact(vec![DataType::Float64, DataType::Int64]),\n            ],\n            Volatility::Immutable,\n        );\n        assert_eq!(sig, expected);\n    }\n}\n"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/integration/__init__.py",
    "content": ""
  },
  {
    "path": "tests/integration/fixtures.py",
    "content": "import os\nimport tempfile\n\nimport dask.dataframe as dd\nimport numpy as np\nimport pandas as pd\nimport pytest\nfrom dask.datasets import timeseries as dd_timeseries\nfrom dask.distributed import Client\n\nfrom tests.utils import assert_eq, convert_nullable_columns\n\ntry:\n    import cudf\n\n    # importing to check for JVM segfault\n    import dask_cudf  # noqa: F401\n    from dask_cuda import LocalCUDACluster  # noqa: F401\nexcept ImportError:\n    cudf = None\n    dask_cudf = None\n    LocalCUDACluster = None\n\n# check if we want to run tests on a distributed client\nDISTRIBUTED_TESTS = os.getenv(\"DASK_SQL_DISTRIBUTED_TESTS\", \"False\").lower() in (\n    \"true\",\n    \"1\",\n)\n\n\n@pytest.fixture()\ndef df_simple():\n    return pd.DataFrame({\"a\": [1, 2, 3], \"b\": [1.1, 2.2, 3.3]})\n\n\n@pytest.fixture()\ndef df_wide():\n    return pd.DataFrame(\n        {\n            \"a\": [0, 1, 2],\n            \"b\": [3, 4, 5],\n            \"c\": [6, 7, 8],\n            \"d\": [9, 10, 11],\n            \"e\": [12, 13, 14],\n        }\n    )\n\n\n@pytest.fixture()\ndef df():\n    np.random.seed(42)\n    return pd.DataFrame(\n        {\n            \"a\": [1.0] * 100 + [2.0] * 200 + [3.0] * 400,\n            \"b\": 10 * np.random.rand(700),\n        }\n    )\n\n\n@pytest.fixture()\ndef department_table():\n    return pd.DataFrame({\"department_name\": [\"English\", \"Math\", \"Science\"]})\n\n\n@pytest.fixture()\ndef user_table_1():\n    return pd.DataFrame({\"user_id\": [2, 1, 2, 3], \"b\": [3, 3, 1, 3]})\n\n\n@pytest.fixture()\ndef user_table_2():\n    return pd.DataFrame({\"user_id\": [1, 1, 2, 4], \"c\": [1, 2, 3, 4]})\n\n\n@pytest.fixture()\ndef long_table():\n    return pd.DataFrame({\"a\": [0] * 100 + [1] * 101 + [2] * 103})\n\n\n@pytest.fixture()\ndef user_table_inf():\n    return pd.DataFrame({\"c\": [3, float(\"inf\"), 1]})\n\n\n@pytest.fixture()\ndef user_table_nan():\n    return pd.DataFrame({\"c\": [3, pd.NA, 
1]}).astype(\"UInt8\")\n\n\n@pytest.fixture()\ndef string_table():\n    return pd.DataFrame(\n        {\n            \"a\": [\n                \"a normal string\",\n                \"%_%\",\n                \"^|()-*[]$\",\n                \"^|()-*[]$\\n%_%\\na normal string\",\n            ]\n        }\n    )\n\n\n@pytest.fixture()\ndef datetime_table():\n    return pd.DataFrame(\n        {\n            \"timezone\": pd.date_range(\n                start=\"2014-08-01 09:00\", freq=\"8H\", periods=6, tz=\"Europe/Berlin\"\n            ),\n            \"no_timezone\": pd.date_range(\n                start=\"2014-08-01 09:00\", freq=\"8H\", periods=6\n            ),\n            \"utc_timezone\": pd.date_range(\n                start=\"2014-08-01 09:00\", freq=\"8H\", periods=6, tz=\"UTC\"\n            ),\n        }\n    )\n\n\n@pytest.fixture()\ndef timeseries():\n    return dd_timeseries(freq=\"1d\").reset_index(drop=True)\n\n\n@pytest.fixture()\ndef parquet_ddf(tmpdir):\n\n    # Write simple parquet dataset\n    df = pd.DataFrame(\n        {\n            \"a\": [1, 2, 3] * 5,\n            \"b\": range(15),\n            \"c\": [\"A\"] * 15,\n            \"d\": [\n                pd.Timestamp(\"2013-08-01 23:00:00\"),\n                pd.Timestamp(\"2014-09-01 23:00:00\"),\n                pd.Timestamp(\"2015-10-01 23:00:00\"),\n            ]\n            * 5,\n            \"index\": range(15),\n        },\n    )\n    dd.from_pandas(df, npartitions=3).to_parquet(os.path.join(tmpdir, \"parquet\"))\n\n    # Read back with dask and apply WHERE query\n    return dd.read_parquet(os.path.join(tmpdir, \"parquet\"), index=\"index\")\n\n\n@pytest.fixture()\ndef gpu_user_table_1(user_table_1):\n    return cudf.from_pandas(user_table_1) if cudf else None\n\n\n@pytest.fixture()\ndef gpu_df(df):\n    return cudf.from_pandas(df) if cudf else None\n\n\n@pytest.fixture()\ndef gpu_long_table(long_table):\n    return cudf.from_pandas(long_table) if cudf else 
None\n\n\n@pytest.fixture()\ndef gpu_string_table(string_table):\n    return cudf.from_pandas(string_table) if cudf else None\n\n\n@pytest.fixture()\ndef gpu_datetime_table(datetime_table):\n    if cudf:\n        # TODO: remove once `from_pandas` has support for timezone-aware data\n        # https://github.com/rapidsai/cudf/issues/13611\n        df = datetime_table.copy()\n        df[\"timezone\"] = df[\"timezone\"].dt.tz_localize(None)\n        df[\"utc_timezone\"] = df[\"utc_timezone\"].dt.tz_localize(None)\n        gdf = cudf.from_pandas(df)\n        gdf[\"timezone\"] = gdf[\"timezone\"].dt.tz_localize(\n            str(datetime_table[\"timezone\"].dt.tz)\n        )\n        gdf[\"utc_timezone\"] = gdf[\"utc_timezone\"].dt.tz_localize(\n            str(datetime_table[\"utc_timezone\"].dt.tz)\n        )\n        return gdf\n    return None\n\n\n@pytest.fixture()\ndef gpu_timeseries(timeseries):\n    return timeseries.to_backend(\"cudf\") if dask_cudf else None\n\n\n@pytest.fixture()\ndef c(\n    df_simple,\n    df_wide,\n    df,\n    department_table,\n    user_table_1,\n    user_table_2,\n    long_table,\n    user_table_inf,\n    user_table_nan,\n    string_table,\n    datetime_table,\n    timeseries,\n    parquet_ddf,\n    gpu_user_table_1,\n    gpu_df,\n    gpu_long_table,\n    gpu_string_table,\n    gpu_datetime_table,\n    gpu_timeseries,\n):\n    dfs = {\n        \"df_simple\": df_simple,\n        \"df_wide\": df_wide,\n        \"df\": df,\n        \"department_table\": department_table,\n        \"user_table_1\": user_table_1,\n        \"user_table_2\": user_table_2,\n        \"long_table\": long_table,\n        \"user_table_inf\": user_table_inf,\n        \"user_table_nan\": user_table_nan,\n        \"string_table\": string_table,\n        \"datetime_table\": datetime_table,\n        \"timeseries\": timeseries,\n        \"parquet_ddf\": parquet_ddf,\n        \"gpu_user_table_1\": gpu_user_table_1,\n        \"gpu_df\": gpu_df,\n        \"gpu_long_table\": 
gpu_long_table,\n        \"gpu_string_table\": gpu_string_table,\n        \"gpu_datetime_table\": gpu_datetime_table,\n        \"gpu_timeseries\": gpu_timeseries,\n    }\n\n    # Lazy import, otherwise the pytest framework has problems\n    from dask_sql.context import Context\n\n    c = Context()\n    for df_name, df in dfs.items():\n        if df is None:\n            continue\n        if hasattr(df, \"npartitions\"):\n            # df is already a dask collection\n            dask_df = df\n        else:\n            dask_df = dd.from_pandas(df, npartitions=3)\n        c.create_table(df_name, dask_df)\n\n    yield c\n\n\n@pytest.fixture()\ndef temporary_data_file():\n    temporary_data_file = os.path.join(\n        tempfile.gettempdir(), os.urandom(24).hex() + \".csv\"\n    )\n\n    yield temporary_data_file\n\n    if os.path.exists(temporary_data_file):\n        os.unlink(temporary_data_file)\n\n\n@pytest.fixture()\ndef assert_query_gives_same_result(engine):\n    np.random.seed(42)\n\n    df1 = dd.from_pandas(\n        pd.DataFrame(\n            {\n                \"user_id\": np.random.choice([1, 2, 3, 4, pd.NA], 100),\n                \"a\": np.random.rand(100),\n                \"b\": np.random.randint(-10, 10, 100),\n            }\n        ),\n        npartitions=3,\n    )\n    df1[\"user_id\"] = df1[\"user_id\"].astype(\"Int64\")\n\n    df2 = dd.from_pandas(\n        pd.DataFrame(\n            {\n                \"user_id\": np.random.choice([1, 2, 3, 4], 100),\n                \"c\": np.random.randint(20, 30, 100),\n                \"d\": np.random.choice([\"a\", \"b\", \"c\", None], 100),\n            }\n        ),\n        npartitions=3,\n    )\n\n    df3 = dd.from_pandas(\n        pd.DataFrame(\n            {\n                \"s\": [\n                    \"\".join(np.random.choice([\"a\", \"B\", \"c\", \"D\"], 10))\n                    for _ in range(100)\n                ]\n                + [None]\n            }\n        ),\n        npartitions=3,\n 
   )\n\n    # df1's user_id is already Int64; casting df2's to match makes joining simpler\n    df2[\"user_id\"] = df2[\"user_id\"].astype(\"Int64\")\n\n    # add some NaNs\n    df1[\"a\"] = df1[\"a\"].apply(\n        lambda a: float(\"nan\") if a > 0.8 else a, meta=(\"a\", \"float\")\n    )\n    df1[\"b_bool\"] = df1[\"b\"].apply(\n        lambda b: pd.NA if b > 5 else b < 0, meta=(\"a\", \"bool\")\n    )\n\n    # Lazy import, otherwise the pytest framework has problems\n    from dask_sql.context import Context\n\n    c = Context()\n    c.create_table(\"df1\", df1)\n    c.create_table(\"df2\", df2)\n    c.create_table(\"df3\", df3)\n\n    df1.compute().to_sql(\"df1\", engine, index=False, if_exists=\"replace\")\n    df2.compute().to_sql(\"df2\", engine, index=False, if_exists=\"replace\")\n    df3.compute().to_sql(\"df3\", engine, index=False, if_exists=\"replace\")\n\n    def _assert_query_gives_same_result(query, sort_columns=None, **kwargs):\n        sql_result = pd.read_sql_query(query, engine)\n        dask_result = c.sql(query).compute()\n\n        # column names may differ between the two engines\n        # because expressions are named differently\n        sql_result.columns = dask_result.columns\n\n        sql_result = sql_result.convert_dtypes()\n        dask_result = dask_result.convert_dtypes()\n\n        convert_nullable_columns(sql_result)\n        convert_nullable_columns(dask_result)\n\n        assert_eq(\n            sql_result, dask_result, check_dtype=False, check_index=False, **kwargs\n        )\n\n    return _assert_query_gives_same_result\n\n\n@pytest.fixture()\ndef gpu_client(request):\n    # allow gpu_client to be used directly as a fixture or parametrized\n    if not hasattr(request, \"param\") or request.param:\n        with LocalCUDACluster(protocol=\"tcp\") as cluster:\n            with Client(cluster) as client:\n                yield client\n    else:\n        with Client() as client:\n            yield client\n\n\n# use session-wide distributed client if specified otherwise 
default to standard fixture\n@pytest.fixture(\n    scope=\"session\" if DISTRIBUTED_TESTS else \"function\", autouse=DISTRIBUTED_TESTS\n)\ndef client():\n    with Client() as client:\n        yield client\n"
  },
  {
    "path": "tests/integration/test_analyze.py",
    "content": "import dask.dataframe as dd\nimport pandas as pd\n\nfrom dask_sql.mappings import python_to_sql_type\nfrom tests.utils import assert_eq\n\n\ndef test_analyze(c, df):\n    result_df = c.sql(\"ANALYZE TABLE df COMPUTE STATISTICS FOR ALL COLUMNS\")\n\n    # extract table and compute stats with Dask manually\n    expected_df = dd.concat(\n        [\n            c.sql(\"SELECT * FROM df\").describe(),\n            pd.DataFrame(\n                {\n                    col: str(python_to_sql_type(df[col].dtype)).lower()\n                    for col in df.columns\n                },\n                index=[\"data_type\"],\n            ),\n            pd.DataFrame(\n                {col: col for col in df.columns},\n                index=[\"col_name\"],\n            ),\n        ]\n    )\n\n    assert_eq(result_df, expected_df)\n\n    result_df = c.sql(\"ANALYZE TABLE df COMPUTE STATISTICS FOR COLUMNS a\")\n\n    assert_eq(result_df, expected_df[[\"a\"]])\n"
  },
  {
    "path": "tests/integration/test_cmd.py",
    "content": "from unittest.mock import MagicMock, patch\n\nimport pytest\nfrom dask import config as dask_config\nfrom prompt_toolkit.application import create_app_session\nfrom prompt_toolkit.input import create_pipe_input\nfrom prompt_toolkit.output import DummyOutput\nfrom prompt_toolkit.shortcuts import PromptSession\n\nfrom dask_sql._compat import PIPE_INPUT_CONTEXT_MANAGER\nfrom dask_sql.cmd import _meta_commands\n\n\n@pytest.fixture(autouse=True, scope=\"function\")\ndef mock_prompt_input():\n    # TODO: remove if prompt-toolkit min version gets bumped\n    if PIPE_INPUT_CONTEXT_MANAGER:\n        with create_pipe_input() as pipe_input:\n            with create_app_session(input=pipe_input, output=DummyOutput()):\n                yield pipe_input\n    else:\n        pipe_input = create_pipe_input()\n        try:\n            with create_app_session(input=pipe_input, output=DummyOutput()):\n                yield pipe_input\n        finally:\n            pipe_input.close()\n\n\ndef _feed_cli_with_input(\n    text,\n    editing_mode=None,\n    clipboard=None,\n    history=None,\n    multiline=False,\n    check_line_ending=True,\n    key_bindings=None,\n):\n    \"\"\"\n    Create a Prompt, feed it with the given user input and return the CLI\n    object.\n    This returns a (result, Application) tuple.\n    \"\"\"\n    # If the given text doesn't end with a newline, the interface won't finish.\n    if check_line_ending:\n        assert text.endswith(\"\\r\")\n\n    inp = create_pipe_input()\n\n    try:\n        inp.send_text(text)\n        session = PromptSession(\n            input=inp,\n            output=DummyOutput(),\n            editing_mode=editing_mode,\n            history=history,\n            multiline=multiline,\n            clipboard=clipboard,\n            key_bindings=key_bindings,\n        )\n\n        result = session.prompt()\n        return session.default_buffer.document, session.app\n\n    finally:\n        inp.close()\n\n\ndef 
test_meta_commands(c, client, capsys):\n    _meta_commands(\"?\", context=c, client=client)\n    captured = capsys.readouterr()\n    assert \"Commands\" in captured.out\n\n    _meta_commands(\"help\", context=c, client=client)\n    captured = capsys.readouterr()\n    assert \"Commands\" in captured.out\n\n    _meta_commands(\"\\\\d?\", context=c, client=client)\n    captured = capsys.readouterr()\n    assert \"Commands\" in captured.out\n\n    _meta_commands(\"\\\\l\", context=c, client=client)\n    captured = capsys.readouterr()\n    assert \"Schemas\" in captured.out\n\n    _meta_commands(\"\\\\dt\", context=c, client=client)\n    captured = capsys.readouterr()\n    assert \"Tables\" in captured.out\n\n    _meta_commands(\"\\\\dm\", context=c, client=client)\n    captured = capsys.readouterr()\n    assert \"Models\" in captured.out\n\n    _meta_commands(\"\\\\df\", context=c, client=client)\n    captured = capsys.readouterr()\n    assert \"Functions\" in captured.out\n\n    _meta_commands(\"\\\\de\", context=c, client=client)\n    captured = capsys.readouterr()\n    assert \"Experiments\" in captured.out\n\n    c.create_schema(\"test_schema\")\n    _meta_commands(\"\\\\dss test_schema\", context=c, client=client)\n    assert c.schema_name == \"test_schema\"\n\n    _meta_commands(\"\\\\dss not_exists\", context=c, client=client)\n    captured = capsys.readouterr()\n    assert \"Schema not_exists not available\\n\" == captured.out\n\n    # FIXME: Revert to 8787 once https://github.com/dask/distributed/issues/8071 is fixed\n    with pytest.raises(\n        OSError,\n        match=\"Timed out .* to tcp://localhost:8788 after 5 s\",\n    ):\n        with dask_config.set({\"distributed.comm.timeouts.connect\": 5}):\n            client = _meta_commands(\"\\\\dsc localhost:8788\", context=c, client=client)\n            assert client.scheduler.__dict__[\"addr\"] == \"localhost:8788\"\n\n\ndef test_connection_info(c, client, capsys):\n    dummy_client = MagicMock()\n    
dummy_client.scheduler.__dict__[\"addr\"] = \"somewhereonearth:8787\"\n    dummy_client.cluster.worker = [\"worker1\", \"worker2\"]\n\n    _meta_commands(\"\\\\conninfo\", context=c, client=dummy_client)\n    captured = capsys.readouterr()\n    assert \"somewhereonearth\" in captured.out\n\n\ndef test_quit(c, client, capsys):\n    dummy_client = MagicMock()\n    with patch(\"sys.exit\", return_value=lambda: \"exit\"):\n        _meta_commands(\"quit\", context=c, client=dummy_client)\n        captured = capsys.readouterr()\n        assert captured.out == \"Quitting dask-sql ...\\n\"\n\n\ndef test_non_meta_commands(c, client, capsys):\n    _meta_commands(\"\\\\x\", context=c, client=client)\n    captured = capsys.readouterr()\n    assert (\n        \"The meta command \\\\x not available, please use commands from below list\"\n        in captured.out\n    )\n\n    res = _meta_commands(\"Select 42 as answer\", context=c, client=client)\n    captured = capsys.readouterr()\n    assert res is False\n"
  },
  {
    "path": "tests/integration/test_compatibility.py",
    "content": "\"\"\"\nThe tests in this module are taken from\nthe fugue-sql module to test the compatibility\nwith their \"understanding\" of SQL\nThey run randomized tests and compare with sqlite.\n\nThere are some changes compared to the fugueSQL\ntests, especially when it comes to sort order:\ndask-sql does not enforce a specific order after groupby\n\"\"\"\n\nimport sqlite3\nfrom datetime import datetime, timedelta\n\nimport dask.config\nimport numpy as np\nimport pandas as pd\nimport pytest\n\nfrom dask_sql import Context\nfrom dask_sql.utils import ParsingException\nfrom tests.utils import assert_eq, convert_nullable_columns, skipif_dask_expr_enabled\n\n\ndef eq_sqlite(sql, **dfs):\n    c = Context()\n    engine = sqlite3.connect(\":memory:\")\n\n    for name, df in dfs.items():\n        c.create_table(name, df)\n        df.to_sql(name, engine, index=False)\n\n    dask_result = c.sql(sql).compute().convert_dtypes()\n    sqlite_result = pd.read_sql(sql, engine).convert_dtypes()\n\n    convert_nullable_columns(dask_result)\n    convert_nullable_columns(sqlite_result)\n\n    datetime_cols = dask_result.select_dtypes(\n        include=[\"datetime64[ns]\"]\n    ).columns.tolist()\n    for col in datetime_cols:\n        sqlite_result[col] = pd.to_datetime(sqlite_result[col])\n\n    sqlite_result = sqlite_result.astype(dask_result.dtypes)\n\n    assert_eq(dask_result, sqlite_result, check_dtype=False, check_index=False)\n\n\ndef make_rand_df(size: int, **kwargs):\n    np.random.seed(0)\n    data = {}\n    for k, v in kwargs.items():\n        if not isinstance(v, tuple):\n            v = (v, 0.0)\n        dt, null_ct = v[0], v[1]\n        if dt is int:\n            s = np.random.randint(10, size=size)\n        elif dt is bool:\n            s = np.where(np.random.randint(2, size=size), True, False)\n        elif dt is float:\n            s = np.random.rand(size)\n        elif dt is str:\n            r = [f\"ssssss{x}\" for x in range(10)]\n            c = 
np.random.randint(10, size=size)\n            s = np.array([r[x] for x in c])\n        elif dt is pd.StringDtype:\n            r = [f\"ssssss{x}\" for x in range(10)]\n            c = np.random.randint(10, size=size)\n            s = np.array([r[x] for x in c])\n            s = pd.array(s, dtype=\"string\")\n        elif dt is datetime:\n            rt = [datetime(2020, 1, 1) + timedelta(days=x) for x in range(10)]\n            c = np.random.randint(10, size=size)\n            s = np.array([rt[x] for x in c])\n        else:\n            raise NotImplementedError\n        ps = pd.Series(s)\n        if null_ct > 0:\n            idx = np.random.choice(size, null_ct, replace=False).tolist()\n            ps[idx] = None\n        data[k] = ps\n    return pd.DataFrame(data)\n\n\ndef test_basic_select_from():\n    df = make_rand_df(5, a=(int, 2), b=(str, 3), c=(float, 4))\n    eq_sqlite(\"SELECT 1 AS a, 1.5 AS b, 'x' AS c\")\n    eq_sqlite(\"SELECT 1+2 AS a, 1.5*3 AS b, 'x' AS c\")\n    eq_sqlite(\"SELECT * FROM a\", a=df)\n    eq_sqlite(\"SELECT * FROM a AS x\", a=df)\n    eq_sqlite(\"SELECT b AS bb, a+1-2*3.0/4 AS cc, x.* FROM a AS x\", a=df)\n    eq_sqlite(\"SELECT *, 1 AS x, 2.5 AS y, 'z' AS z FROM a AS x\", a=df)\n    eq_sqlite(\"SELECT *, -(1.0+a)/3 AS x, +(2.5) AS y FROM a AS x\", a=df)\n\n\ndef test_case_when():\n    a = make_rand_df(100, a=(int, 20), b=(str, 30), c=(float, 40))\n    eq_sqlite(\n        \"\"\"\n        SELECT a,b,c,\n            CASE\n                WHEN a<10 THEN a+3\n                WHEN c<0.5 THEN a+5\n                ELSE (1+2)*3 + a\n            END AS d\n        FROM a\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_drop_duplicates():\n    # simplest\n    a = make_rand_df(100, a=int, b=int)\n    eq_sqlite(\n        \"\"\"\n        SELECT DISTINCT b, a FROM a\n        ORDER BY a NULLS LAST, b NULLS FIRST\n        \"\"\",\n        a=a,\n    )\n    # mix of number and nan\n    a = make_rand_df(100, a=(int, 50), b=(int, 50))\n    eq_sqlite(\n 
       \"\"\"\n        SELECT DISTINCT b, a FROM a\n        ORDER BY a NULLS LAST, b NULLS FIRST\n        \"\"\",\n        a=a,\n    )\n    # mix of number and string and nulls\n    a = make_rand_df(100, a=(int, 50), b=(str, 50), c=float)\n    eq_sqlite(\n        \"\"\"\n        SELECT DISTINCT b, a FROM a\n        ORDER BY a NULLS LAST, b NULLS FIRST\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_order_by_no_limit():\n    a = make_rand_df(100, a=(int, 50), b=(str, 50), c=float)\n    eq_sqlite(\n        \"\"\"\n        SELECT DISTINCT b, a FROM a\n        ORDER BY a NULLS LAST, b NULLS FIRST\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_order_by_limit():\n    a = make_rand_df(100, a=(int, 50), b=(str, 50), c=float)\n    eq_sqlite(\n        \"\"\"\n        SELECT DISTINCT b, a FROM a LIMIT 0\n        \"\"\",\n        a=a,\n    )\n    eq_sqlite(\n        \"\"\"\n        SELECT DISTINCT b, a FROM a ORDER BY a NULLS FIRST, b NULLS FIRST LIMIT 2\n        \"\"\",\n        a=a,\n    )\n    eq_sqlite(\n        \"\"\"\n        SELECT b, a FROM a\n            ORDER BY a NULLS LAST, b NULLS FIRST LIMIT 10\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_where():\n    df = make_rand_df(100, a=(int, 30), b=(str, 30), c=(float, 30))\n    eq_sqlite(\"SELECT * FROM a WHERE TRUE OR TRUE\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE TRUE AND TRUE\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE FALSE OR FALSE\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE FALSE AND FALSE\", a=df)\n\n    eq_sqlite(\"SELECT * FROM a WHERE TRUE OR b<='ssssss8'\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE TRUE AND b<='ssssss8'\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE FALSE OR b<='ssssss8'\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE FALSE AND b<='ssssss8'\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE a=10 OR b<='ssssss8'\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE c IS NOT NULL OR (a<5 AND b IS NOT NULL)\", a=df)\n\n    df = make_rand_df(100, a=(float, 30), b=(float, 
30), c=(float, 30))\n    eq_sqlite(\"SELECT * FROM a WHERE a<0.5 AND b<0.5 AND c<0.5\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE a<0.5 OR b<0.5 AND c<0.5\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE a IS NULL OR (b<0.5 AND c<0.5)\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE a*b IS NULL OR (b*c<0.5 AND c*a<0.5)\", a=df)\n\n\ndef test_in_between():\n    df = make_rand_df(10, a=(int, 3), b=(str, 3))\n    eq_sqlite(\"SELECT * FROM a WHERE a IN (2,4,6)\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE a BETWEEN 2 AND 4+1\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE a NOT IN (2,4,6) AND a IS NOT NULL\", a=df)\n    eq_sqlite(\"SELECT * FROM a WHERE a NOT BETWEEN 2 AND 4+1 AND a IS NOT NULL\", a=df)\n    eq_sqlite(\n        \"SELECT * FROM a WHERE SUBSTR(b,1,2) IN ('ss','s') AND a NOT BETWEEN 3 AND 5 and a IS NOT NULL\",\n        a=df,\n    )\n\n\ndef test_join_inner():\n    a = make_rand_df(100, a=(int, 40), b=(str, 40), c=(float, 40))\n    b = make_rand_df(80, d=(float, 10), a=(int, 10), b=(str, 10))\n    eq_sqlite(\n        \"\"\"\n        SELECT\n        a.*, d, d*c AS x\n        FROM a\n        INNER JOIN b ON a.a=b.a AND a.b=b.b\n        ORDER BY a.a NULLS FIRST, a.b NULLS FIRST, a.c NULLS FIRST, d NULLS FIRST\n        \"\"\",\n        a=a,\n        b=b,\n    )\n\n\ndef test_join_left():\n    a = make_rand_df(100, a=(int, 40), b=(str, 40), c=(float, 40))\n    b = make_rand_df(80, d=(float, 10), a=(int, 10), b=(str, 10))\n    eq_sqlite(\n        \"\"\"\n        SELECT\n        a.*, d, d*c AS x\n        FROM a LEFT JOIN b ON a.a=b.a AND a.b=b.b\n        ORDER BY a.a NULLS FIRST, a.b NULLS FIRST, a.c NULLS FIRST, d NULLS FIRST\n        \"\"\",\n        a=a,\n        b=b,\n    )\n\n\ndef test_join_cross():\n    a = make_rand_df(10, a=(int, 4), b=(str, 4), c=(float, 4))\n    b = make_rand_df(20, dd=(float, 1), aa=(int, 1), bb=(str, 1))\n    eq_sqlite(\n        \"\"\"\n        SELECT * FROM a\n            CROSS JOIN b\n        ORDER BY a.a NULLS FIRST, a.b NULLS 
FIRST, a.c NULLS FIRST, dd NULLS FIRST\n        \"\"\",\n        a=a,\n        b=b,\n    )\n\n\ndef test_join_multi():\n    a = make_rand_df(100, a=(int, 40), b=(str, 40), c=(float, 40))\n    b = make_rand_df(80, d=(float, 10), a=(int, 10), b=(str, 10))\n    c = make_rand_df(80, dd=(float, 10), a=(int, 10), b=(str, 10))\n    eq_sqlite(\n        \"\"\"\n        SELECT a.*,d,dd FROM a\n            INNER JOIN b ON a.a=b.a AND a.b=b.b\n            INNER JOIN c ON a.a=c.a AND c.b=b.b\n        ORDER BY a.a NULLS FIRST, a.b NULLS FIRST, a.c NULLS FIRST, dd NULLS FIRST, d NULLS FIRST\n        \"\"\",\n        a=a,\n        b=b,\n        c=c,\n    )\n\n\ndef test_single_agg_count_no_group_by():\n    a = make_rand_df(\n        100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)\n    )\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            COUNT(a) AS c_a,\n            COUNT(DISTINCT a) AS cd_a\n        FROM a\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_multi_agg_count_no_group_by():\n    a = make_rand_df(\n        100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)\n    )\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            COUNT(a) AS c_a,\n            COUNT(DISTINCT a) AS cd_a,\n            COUNT(b) AS c_b,\n            COUNT(DISTINCT b) AS cd_b,\n            COUNT(c) AS c_c,\n            COUNT(DISTINCT c) AS cd_c,\n            COUNT(d) AS c_d,\n            COUNT(DISTINCT d) AS cd_d,\n            COUNT(e) AS c_e,\n            COUNT(DISTINCT e) AS cd_e\n        FROM a\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_multi_agg_count_no_group_by_dupe_distinct():\n    a = make_rand_df(\n        100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)\n    )\n    # note that this test repeats the expression `COUNT(DISTINCT a)`\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            COUNT(a) AS c_a,\n            COUNT(DISTINCT a) AS cd_a,\n            COUNT(b) AS c_b,\n            COUNT(DISTINCT b) 
AS cd_b,\n            COUNT(c) AS c_c,\n            COUNT(DISTINCT c) AS cd_c,\n            COUNT(d) AS c_d,\n            COUNT(DISTINCT d) AS cd_d,\n            COUNT(e) AS c_e,\n            COUNT(DISTINCT a) AS cd_e\n        FROM a\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_agg_count_distinct_group_by():\n    a = make_rand_df(\n        100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)\n    )\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            a,\n            COUNT(DISTINCT b) AS cd_b\n        FROM a\n        GROUP BY a\n        ORDER BY a NULLS FIRST\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_agg_count_no_group_by():\n    a = make_rand_df(\n        100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)\n    )\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            COUNT(a) AS cd_a\n        FROM a\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_agg_count_distinct_no_group_by():\n    a = make_rand_df(\n        100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)\n    )\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            COUNT(DISTINCT a) AS cd_a\n        FROM a\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_agg_count():\n    a = make_rand_df(\n        100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)\n    )\n    # note that this test repeats the expression `COUNT(DISTINCT a)`\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            a, b, a+1 AS c,\n            COUNT(c) AS c_c,\n            COUNT(DISTINCT c) AS cd_c,\n            COUNT(d) AS c_d,\n            COUNT(DISTINCT d) AS cd_d,\n            COUNT(e) AS c_e,\n            COUNT(DISTINCT a) AS cd_e\n        FROM a GROUP BY a, b ORDER BY\n            a NULLS FIRST,\n            b NULLS FIRST\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_agg_sum_avg_no_group_by():\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            SUM(a) AS sum_a,\n            AVG(a) AS avg_a\n        
FROM a\n        \"\"\",\n        a=pd.DataFrame({\"a\": [float(\"2.3\")]}),\n    )\n    a = make_rand_df(\n        100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)\n    )\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            SUM(a) AS sum_a,\n            AVG(a) AS avg_a,\n            SUM(c) AS sum_c,\n            AVG(c) AS avg_c,\n            SUM(e) AS sum_e,\n            AVG(e) AS avg_e,\n            SUM(a)+AVG(e) AS mix_1,\n            SUM(a+e) AS mix_2\n        FROM a\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_agg_sum_avg():\n    a = make_rand_df(\n        100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)\n    )\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            a,b, a+1 AS c,\n            SUM(c) AS sum_c,\n            AVG(c) AS avg_c,\n            SUM(e) AS sum_e,\n            AVG(e) AS avg_e,\n            SUM(a)+AVG(e) AS mix_1,\n            SUM(a+e) AS mix_2\n        FROM a GROUP BY a, b ORDER BY\n            a NULLS FIRST,\n            b NULLS FIRST\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_agg_min_max_no_group_by():\n    a = make_rand_df(\n        100,\n        a=(int, 50),\n        b=(str, 50),\n        c=(int, 30),\n        d=(str, 40),\n        e=(float, 40),\n        f=(pd.StringDtype, 40),\n        g=(datetime, 40),\n    )\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            MIN(a) AS min_a,\n            MAX(a) AS max_a,\n            MIN(b) AS min_b,\n            MAX(b) AS max_b,\n            MIN(c) AS min_c,\n            MAX(c) AS max_c,\n            MIN(d) AS min_d,\n            MAX(d) AS max_d,\n            MIN(e) AS min_e,\n            MAX(e) AS max_e,\n            MIN(f) as min_f,\n            MAX(f) as max_f,\n            MIN(g) as min_g,\n            MAX(g) as max_g,\n            MIN(a+e) AS mix_1,\n            MIN(a)+MIN(e) AS mix_2\n        FROM a\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_agg_min_max():\n    a = make_rand_df(\n        100,\n        
a=(int, 50),\n        b=(str, 50),\n        c=(int, 30),\n        d=(str, 40),\n        e=(float, 40),\n        f=(pd.StringDtype, 40),\n        g=(datetime, 40),\n    )\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            a, b, a+1 AS c,\n            MIN(c) AS min_c,\n            MAX(c) AS max_c,\n            MIN(d) AS min_d,\n            MAX(d) AS max_d,\n            MIN(e) AS min_e,\n            MAX(e) AS max_e,\n            MIN(f) AS min_f,\n            MAX(f) AS max_f,\n            MIN(g) AS min_g,\n            MAX(g) AS max_g,\n            MIN(a+e) AS mix_1,\n            MIN(a)+MIN(e) AS mix_2\n        FROM a GROUP BY a, b ORDER BY\n            a NULLS FIRST,\n            b NULLS FIRST\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_window_row_number():\n    a = make_rand_df(10, a=int, b=(float, 5))\n    eq_sqlite(\n        \"\"\"\n        SELECT *,\n            ROW_NUMBER() OVER (ORDER BY a ASC, b DESC NULLS FIRST) AS a1,\n            ROW_NUMBER() OVER (ORDER BY a ASC, b DESC NULLS LAST) AS a2,\n            ROW_NUMBER() OVER (ORDER BY a ASC, b ASC NULLS FIRST) AS a3,\n            ROW_NUMBER() OVER (ORDER BY a ASC, b ASC NULLS LAST) AS a4,\n            ROW_NUMBER() OVER (PARTITION BY a ORDER BY a,b DESC NULLS FIRST) AS a5\n        FROM a\n        ORDER BY a, b NULLS FIRST\n        \"\"\",\n        a=a,\n    )\n\n    a = make_rand_df(100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=float)\n    eq_sqlite(\n        \"\"\"\n        SELECT *,\n            ROW_NUMBER() OVER (ORDER BY a ASC NULLS LAST, b DESC NULLS FIRST, e) AS a1,\n            ROW_NUMBER() OVER (ORDER BY a ASC NULLS FIRST, b DESC NULLS LAST, e) AS a2,\n            ROW_NUMBER() OVER (PARTITION BY a ORDER BY a NULLS FIRST, b DESC NULLS LAST, e) AS a3,\n            ROW_NUMBER() OVER (PARTITION BY a,c ORDER BY a NULLS FIRST, b DESC NULLS LAST, e) AS a4\n        FROM a\n        ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST, d NULLS FIRST, e\n        \"\"\",\n        a=a,\n   
 )\n\n\ndef test_window_row_number_partition_by():\n    a = make_rand_df(100, a=int, b=(float, 50))\n    eq_sqlite(\n        \"\"\"\n        SELECT *,\n            ROW_NUMBER() OVER (PARTITION BY a ORDER BY a, b DESC NULLS FIRST) AS a5\n        FROM a\n        ORDER BY a, b NULLS FIRST, a5\n        \"\"\",\n        a=a,\n    )\n\n    a = make_rand_df(100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=float)\n    eq_sqlite(\n        \"\"\"\n        SELECT *,\n            ROW_NUMBER() OVER (PARTITION BY a ORDER BY a NULLS FIRST, b DESC NULLS FIRST, e) AS a3,\n            ROW_NUMBER() OVER (PARTITION BY a,c ORDER BY a NULLS FIRST, b DESC NULLS FIRST, e) AS a4\n        FROM a\n        ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST, d NULLS FIRST, e\n        \"\"\",\n        a=a,\n    )\n\n\n@pytest.mark.xfail(\n    reason=\"Need to implement rank/lead/lag window functions, see https://github.com/dask-contrib/dask-sql/issues/878\"\n)\ndef test_window_ranks():\n    a = make_rand_df(100, a=int, b=(float, 50), c=(str, 50))\n    eq_sqlite(\n        \"\"\"\n        SELECT *,\n            RANK() OVER (PARTITION BY a ORDER BY b DESC NULLS FIRST, c) AS a1,\n            DENSE_RANK() OVER (ORDER BY a ASC, b DESC NULLS LAST, c DESC) AS a2,\n            PERCENT_RANK() OVER (ORDER BY a ASC, b ASC NULLS LAST, c) AS a4\n        FROM a\n        \"\"\",\n        a=a,\n    )\n\n\n@pytest.mark.xfail(\n    reason=\"Need to implement rank/lead/lag window functions, see https://github.com/dask-contrib/dask-sql/issues/878\"\n)\ndef test_window_ranks_partition_by():\n    a = make_rand_df(100, a=int, b=(float, 50), c=(str, 50))\n    eq_sqlite(\n        \"\"\"\n        SELECT *,\n            RANK() OVER (PARTITION BY a ORDER BY b DESC NULLS FIRST, c) AS a1,\n            DENSE_RANK() OVER\n                (PARTITION BY a ORDER BY a ASC, b DESC NULLS LAST, c DESC)\n                AS a2,\n            PERCENT_RANK() OVER\n                (PARTITION BY a ORDER BY a ASC, b ASC NULLS 
LAST, c) AS a4\n        FROM a\n        \"\"\",\n        a=a,\n    )\n\n\n@pytest.mark.xfail(\n    reason=\"Need to implement rank/lead/lag window functions, see https://github.com/dask-contrib/dask-sql/issues/878\"\n)\ndef test_window_lead_lag():\n    a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50))\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            LEAD(b,1) OVER (ORDER BY a) AS a1,\n            LEAD(b,2,10) OVER (ORDER BY a) AS a2,\n            LEAD(b,1) OVER (PARTITION BY c ORDER BY a) AS a3,\n            LEAD(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS a5,\n\n            LAG(b,1) OVER (ORDER BY a) AS b1,\n            LAG(b,2,10) OVER (ORDER BY a) AS b2,\n            LAG(b,1) OVER (PARTITION BY c ORDER BY a) AS b3,\n            LAG(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS b5\n        FROM a\n        \"\"\",\n        a=a,\n    )\n\n\n@pytest.mark.xfail(\n    reason=\"Need to implement rank/lead/lag window functions, see https://github.com/dask-contrib/dask-sql/issues/878\"\n)\ndef test_window_lead_lag_partition_by():\n    a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50))\n    eq_sqlite(\n        \"\"\"\n        SELECT\n            LEAD(b,1,10) OVER (PARTITION BY c ORDER BY a) AS a3,\n            LEAD(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS a5,\n\n            LAG(b,1) OVER (PARTITION BY c ORDER BY a) AS b3,\n            LAG(b,1) OVER (PARTITION BY c ORDER BY b, a ASC NULLS LAST) AS b5\n        FROM a\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_window_sum_avg():\n    a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50))\n    for func in [\"SUM\", \"AVG\"]:\n        eq_sqlite(\n            f\"\"\"\n            SELECT a,b,\n                {func}(b) OVER () AS a1,\n                {func}(b) OVER (PARTITION BY c) AS a2,\n                {func}(b+a) OVER (PARTITION BY c,b) AS a3,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a NULLS FIRST\n                    ROWS 
BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST\n                    ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a NULLS FIRST\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)\n                    AS a6\n            FROM a\n            ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST\n            \"\"\",\n            a=a,\n        )\n        # irregular windows\n        eq_sqlite(\n            f\"\"\"\n            SELECT a,b,\n                {func}(b) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST\n                    ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6,\n                {func}(b) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST\n                    ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) AS a7,\n                {func}(b) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST\n                    ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS a8\n            FROM a\n            ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST\n            \"\"\",\n            a=a,\n        )\n\n\ndef test_window_sum_avg_partition_by():\n    a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50))\n    for func in [\"SUM\", \"AVG\"]:\n        eq_sqlite(\n            f\"\"\"\n            SELECT a,b,\n                {func}(b+a) OVER (PARTITION BY c,b) AS a3,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a NULLS FIRST\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST\n                    ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a NULLS FIRST\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)\n                    AS a6\n            FROM a\n            ORDER BY a NULLS FIRST, b 
NULLS FIRST, c NULLS FIRST\n            \"\"\",\n            a=a,\n        )\n        # irregular windows\n        eq_sqlite(\n            f\"\"\"\n            SELECT a,b,\n                {func}(b) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST\n                    ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6,\n                {func}(b) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST\n                    ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) AS a7,\n                {func}(b) OVER (PARTITION BY b ORDER BY a DESC NULLS FIRST\n                    ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS a8\n            FROM a\n            ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST\n            \"\"\",\n            a=a,\n        )\n\n\ndef test_window_min_max():\n    for func in [\"MIN\", \"MAX\"]:\n        a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50))\n        eq_sqlite(\n            f\"\"\"\n            SELECT a,b,\n                {func}(b) OVER () AS a1,\n                {func}(b) OVER (PARTITION BY c) AS a2,\n                {func}(b+a) OVER (PARTITION BY c,b) AS a3,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)\n                    AS a6\n            FROM a\n            ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST\n            \"\"\",\n            a=a,\n        )\n        # irregular windows\n        eq_sqlite(\n            f\"\"\"\n            SELECT a,b,\n                {func}(b) OVER (ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6,\n                {func}(b) OVER (ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING 
AND 1 FOLLOWING) AS a7,\n                {func}(b) OVER (ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS a8\n            FROM a\n            ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST\n            \"\"\",\n            a=a,\n        )\n        b = make_rand_df(10, a=float, b=(int, 0), c=(str, 0))\n        eq_sqlite(\n            f\"\"\"\n            SELECT a,b,\n                {func}(b) OVER (PARTITION BY b ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6\n            FROM a\n            ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST\n            \"\"\",\n            a=b,\n        )\n\n\ndef test_window_min_max_partition_by():\n    for func in [\"MIN\", \"MAX\"]:\n        a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50))\n        eq_sqlite(\n            f\"\"\"\n            SELECT a,b,\n                {func}(b) OVER (PARTITION BY c) AS a2,\n                {func}(b+a) OVER (PARTITION BY c,b) AS a3,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)\n                    AS a6\n            FROM a\n            ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST\n            \"\"\",\n            a=a,\n        )\n        b = make_rand_df(10, a=float, b=(int, 0), c=(str, 0))\n        eq_sqlite(\n            f\"\"\"\n            SELECT a,b,\n                {func}(b) OVER (PARTITION BY b ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS a6\n            FROM a\n            ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST\n            \"\"\",\n            a=b,\n        )\n\n\n# 
TODO: investigate source of window count deadlocks\n@skipif_dask_expr_enabled(\"Deadlocks with query planning enabled\")\ndef test_window_count():\n    for func in [\"COUNT\"]:\n        a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50))\n        eq_sqlite(\n            f\"\"\"\n            SELECT a,b,\n                {func}(b) OVER () AS a1,\n                {func}(b) OVER (PARTITION BY c) AS a2,\n                {func}(b+a) OVER (PARTITION BY c,b) AS a3,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)\n                    AS a6,\n                {func}(c) OVER () AS b1,\n                {func}(c) OVER (PARTITION BY c) AS b2,\n                {func}(c) OVER (PARTITION BY c,b) AS b3,\n                {func}(c) OVER (PARTITION BY b ORDER BY a\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS b4,\n                {func}(c) OVER (PARTITION BY b ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS b5,\n                {func}(c) OVER (PARTITION BY b ORDER BY a\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)\n                    AS b6\n            FROM a\n            ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST\n            \"\"\",\n            a=a,\n        )\n        # irregular windows\n        eq_sqlite(\n            f\"\"\"\n            SELECT a,b,\n                {func}(b) OVER (ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS a6,\n                {func}(b) OVER (PARTITION BY c ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS a9,\n             
   {func}(c) OVER (ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS b6,\n                {func}(c) OVER (PARTITION BY c ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS b9\n            FROM a\n            ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST\n            \"\"\",\n            a=a,\n        )\n\n\n# TODO: investigate source of window count deadlocks\n@skipif_dask_expr_enabled(\"Deadlocks with query planning enabled\")\ndef test_window_count_partition_by():\n    for func in [\"COUNT\"]:\n        a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50))\n        eq_sqlite(\n            f\"\"\"\n            SELECT a,b,\n                {func}(b) OVER (PARTITION BY c) AS a2,\n                {func}(b+a) OVER (PARTITION BY c,b) AS a3,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS a4,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS a5,\n                {func}(b+a) OVER (PARTITION BY b ORDER BY a\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)\n                    AS a6,\n                {func}(c) OVER (PARTITION BY c) AS b2,\n                {func}(c) OVER (PARTITION BY c,b) AS b3,\n                {func}(c) OVER (PARTITION BY b ORDER BY a\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS b4,\n                {func}(c) OVER (PARTITION BY b ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS b5,\n                {func}(c) OVER (PARTITION BY b ORDER BY a\n                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)\n                    AS b6\n            FROM a\n            ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST\n            \"\"\",\n            a=a,\n        )\n        # irregular windows\n        eq_sqlite(\n  
          f\"\"\"\n            SELECT a,b,\n                {func}(b) OVER (PARTITION BY c ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS a9,\n                {func}(c) OVER (PARTITION BY c ORDER BY a DESC\n                    ROWS BETWEEN 2 PRECEDING AND 0 PRECEDING) AS b9\n            FROM a\n            ORDER BY a NULLS FIRST, b NULLS FIRST, c NULLS FIRST\n            \"\"\",\n            a=a,\n        )\n\n\ndef test_nested_query():\n    a = make_rand_df(100, a=float, b=(int, 50), c=(str, 50))\n    eq_sqlite(\n        \"\"\"\n        SELECT * FROM (\n        SELECT *,\n            ROW_NUMBER() OVER (PARTITION BY c ORDER BY b NULLS FIRST, a ASC NULLS LAST) AS r\n        FROM a)\n        WHERE r=1\n        ORDER BY a NULLS LAST, b NULLS LAST, c NULLS LAST\n        \"\"\",\n        a=a,\n    )\n\n\ndef test_union():\n    a = make_rand_df(30, b=(int, 10), c=(str, 10))\n    b = make_rand_df(80, b=(int, 50), c=(str, 50))\n    c = make_rand_df(100, b=(int, 50), c=(str, 50))\n    eq_sqlite(\n        \"\"\"\n        SELECT * FROM a\n            UNION SELECT * FROM b\n            UNION SELECT * FROM c\n        ORDER BY b NULLS FIRST, c NULLS FIRST\n        \"\"\",\n        a=a,\n        b=b,\n        c=c,\n    )\n    eq_sqlite(\n        \"\"\"\n        SELECT * FROM a\n            UNION ALL SELECT * FROM b\n            UNION ALL SELECT * FROM c\n        ORDER BY b NULLS FIRST, c NULLS FIRST\n        \"\"\",\n        a=a,\n        b=b,\n        c=c,\n    )\n\n\n@pytest.mark.xfail(\n    reason=\"'ANTI' joins not supported yet, see https://github.com/dask-contrib/dask-sql/issues/879\"\n)\ndef test_except():\n    a = make_rand_df(30, b=(int, 10), c=(str, 10))\n    b = make_rand_df(80, b=(int, 50), c=(str, 50))\n    c = make_rand_df(100, b=(int, 50), c=(str, 50))\n    eq_sqlite(\n        \"\"\"\n        SELECT * FROM c\n            EXCEPT SELECT * FROM b\n            EXCEPT SELECT * FROM c\n        \"\"\",\n        a=a,\n        b=b,\n       
 c=c,\n    )\n\n\n@pytest.mark.xfail(\n    reason=\"INTERSECT is not compliant with SQLite, see https://github.com/dask-contrib/dask-sql/issues/880\"\n)\ndef test_intersect():\n    a = make_rand_df(30, b=(int, 10), c=(str, 10))\n    b = make_rand_df(80, b=(int, 50), c=(str, 50))\n    c = make_rand_df(100, b=(int, 50), c=(str, 50))\n    eq_sqlite(\n        \"\"\"\n        SELECT * FROM c\n            INTERSECT SELECT * FROM b\n            INTERSECT SELECT * FROM c\n        \"\"\",\n        a=a,\n        b=b,\n        c=c,\n    )\n\n\ndef test_with():\n    a = make_rand_df(30, a=(int, 10), b=(str, 10))\n    b = make_rand_df(80, ax=(int, 10), bx=(str, 10))\n    eq_sqlite(\n        \"\"\"\n        WITH\n            aa AS (\n                SELECT a AS aa, b AS bb FROM a\n            ),\n            c AS (\n                SELECT aa-1 AS aa, bb FROM aa\n            )\n        SELECT * FROM c UNION SELECT * FROM b\n        ORDER BY aa NULLS FIRST, bb NULLS FIRST\n        \"\"\",\n        a=a,\n        b=b,\n    )\n\n\ndef test_integration_1():\n    a = make_rand_df(100, a=int, b=str, c=float, d=int, e=bool, f=str, g=str, h=float)\n    eq_sqlite(\n        \"\"\"\n        WITH\n            a1 AS (\n                SELECT a+1 AS a, b, c FROM a\n            ),\n            a2 AS (\n                SELECT a,MAX(b) AS b_max, AVG(c) AS c_avg FROM a GROUP BY a\n            ),\n            a3 AS (\n                SELECT d+2 AS d, f, g, h FROM a WHERE e\n            )\n        SELECT a1.a,b,c,b_max,c_avg,f,g,h FROM a1\n            INNER JOIN a2 ON a1.a=a2.a\n            LEFT JOIN a3 ON a1.a=a3.d\n        ORDER BY a1.a NULLS FIRST, b NULLS FIRST, c NULLS FIRST, f NULLS FIRST, g NULLS FIRST, h NULLS FIRST\n        \"\"\",\n        a=a,\n    )\n\n\n@pytest.mark.parametrize(\n    \"case_sensitive\",\n    [\n        False,\n        pytest.param(\n            True,\n            marks=pytest.mark.xfail(\n                reason=\"https://github.com/dask-contrib/dask-sql/issues/1092\"\n   
         ),\n        ),\n    ],\n)\ndef test_query_case_sensitivity(case_sensitive):\n    c = Context()\n    df = pd.DataFrame({\"id\": [0, 1], \"VAL\": [1, 2]})\n\n    c.create_table(\"test\", df)\n    q1 = \"select ID from test\"\n    q2 = \"select val from test\"\n    q3 = \"select Id, VAl from test\"\n    with dask.config.set({\"sql.identifier.case_sensitive\": case_sensitive}):\n        if case_sensitive:\n            with pytest.raises(ParsingException):\n                c.sql(q1)\n            with pytest.raises(ParsingException):\n                c.sql(q2)\n            with pytest.raises(ParsingException):\n                c.sql(q3)\n            result = c.sql(\"SELECT VAL from test\")\n            assert_eq(result, df[[\"VAL\"]])\n        else:\n            df.columns = df.columns.str.lower()\n            result = c.sql(q1)\n            assert_eq(result, df[[\"id\"]])\n            result = c.sql(q2)\n            assert_eq(result, df[[\"val\"]])\n            result = c.sql(q3)\n            assert_eq(result, df[[\"id\", \"val\"]])\n\n\ndef test_column_name_starting_with_number():\n    c = Context()\n    df = pd.DataFrame({\"a\": range(10), \"1b\": range(10)})\n    c.create_table(\"df\", df)\n\n    result = c.sql(\n        \"\"\"\n        SELECT \"1b\" AS x FROM df\n        \"\"\"\n    )\n    expected = pd.DataFrame({\"x\": range(10)})\n    assert_eq(result, expected)\n\n    result = c.sql(\n        \"\"\"\n        SELECT (CASE WHEN \"1b\"=1 THEN 0 END) AS x FROM df\n        \"\"\"\n    )\n    expected = pd.DataFrame(\n        {\"x\": [None, 0, None, None, None, None, None, None, None, None]}\n    )\n    assert_eq(result, expected)\n"
  },
  {
    "path": "tests/integration/test_complex.py",
    "content": "from dask.datasets import timeseries\n\n\ndef test_complex_query(c):\n    df = timeseries(freq=\"1d\").persist()\n    c.create_table(\"timeseries\", df)\n\n    result = c.sql(\n        \"\"\"\n        SELECT\n            lhs.name,\n            lhs.id,\n            lhs.x\n        FROM\n            timeseries AS lhs\n        JOIN\n            (\n                SELECT\n                    name AS max_name,\n                    MAX(x) AS max_x\n                FROM timeseries\n                GROUP BY name\n            ) AS rhs\n        ON\n            lhs.name = rhs.max_name AND\n            lhs.x = rhs.max_x\n    \"\"\"\n    ).compute()\n\n    assert len(result) > 0\n"
  },
  {
    "path": "tests/integration/test_create.py",
    "content": "import dask.dataframe as dd\nimport pandas as pd\nimport pytest\n\nimport dask_sql\nfrom tests.utils import assert_eq\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_create_from_csv(c, df, temporary_data_file, gpu):\n    df.to_csv(temporary_data_file, index=False)\n\n    c.sql(\n        f\"\"\"\n        CREATE TABLE\n            new_table\n        WITH (\n            location = '{temporary_data_file}',\n            format = 'csv',\n            gpu = {gpu}\n        )\n    \"\"\"\n    )\n\n    result_df = c.sql(\n        \"\"\"\n        SELECT * FROM new_table\n    \"\"\"\n    )\n\n    assert_eq(result_df, df)\n\n\n@pytest.mark.parametrize(\n    \"gpu\",\n    [\n        False,\n        pytest.param(True, marks=pytest.mark.gpu),\n    ],\n)\ndef test_cluster_memory(client, c, df, gpu):\n    client.publish_dataset(df=dd.from_pandas(df, npartitions=1))\n\n    c.sql(\n        f\"\"\"\n        CREATE TABLE\n            new_table\n        WITH (\n            location = 'df',\n            format = 'memory',\n            gpu = {gpu}\n        )\n    \"\"\"\n    )\n\n    return_df = c.sql(\n        \"\"\"\n        SELECT * FROM new_table\n    \"\"\"\n    )\n\n    assert_eq(df, return_df)\n\n    client.unpublish_dataset(\"df\")\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_create_from_csv_persist(c, df, temporary_data_file, gpu):\n    df.to_csv(temporary_data_file, index=False)\n\n    c.sql(\n        f\"\"\"\n        CREATE TABLE\n            new_table\n        WITH (\n            location = '{temporary_data_file}',\n            format = 'csv',\n            persist = True,\n            gpu = {gpu}\n        )\n    \"\"\"\n    )\n\n    return_df = c.sql(\n        \"\"\"\n        SELECT * FROM new_table\n    \"\"\"\n    )\n\n    assert_eq(df, return_df)\n\n\ndef test_wrong_create(c):\n    with pytest.raises(AttributeError):\n        c.sql(\n            
\"\"\"\n            CREATE TABLE\n                new_table\n            WITH (\n                format = 'csv'\n            )\n        \"\"\"\n        )\n\n    with pytest.raises(AttributeError):\n        c.sql(\n            \"\"\"\n            CREATE TABLE\n                new_table\n            WITH (\n                format = 'strange',\n                location = 'some/path'\n            )\n        \"\"\"\n        )\n\n\ndef test_create_from_query(c, df):\n    with pytest.raises(RuntimeError):\n        c.sql(\n            \"\"\"\n            CREATE OR REPLACE TABLE\n                other.new_table\n            AS (\n                SELECT * FROM df\n            )\n        \"\"\"\n        )\n\n    c.sql(\n        \"\"\"\n        CREATE OR REPLACE TABLE\n            new_table\n        AS (\n            SELECT * FROM df\n        )\n    \"\"\"\n    )\n\n    return_df = c.sql(\n        \"\"\"\n        SELECT * FROM new_table\n    \"\"\"\n    )\n\n    assert_eq(df, return_df)\n\n    with pytest.raises(RuntimeError):\n        c.sql(\n            \"\"\"\n            CREATE OR REPLACE VIEW\n                other.new_table\n            AS (\n                SELECT * FROM df\n            )\n        \"\"\"\n        )\n\n    c.sql(\n        \"\"\"\n        CREATE OR REPLACE VIEW\n            new_table\n        AS (\n            SELECT * FROM df\n        )\n    \"\"\"\n    )\n\n    return_df = c.sql(\n        \"\"\"\n        SELECT * FROM new_table\n    \"\"\"\n    )\n\n    assert_eq(df, return_df)\n\n\n@pytest.mark.parametrize(\n    \"gpu\",\n    [\n        False,\n        pytest.param(\n            True,\n            marks=pytest.mark.gpu,\n        ),\n    ],\n)\ndef test_view_table_persist(c, temporary_data_file, df, gpu):\n    df.to_csv(temporary_data_file, index=False)\n    c.sql(\n        f\"\"\"\n        CREATE TABLE\n            new_table\n        WITH (\n            location = '{temporary_data_file}',\n            format = 'csv',\n            gpu = {gpu}\n        
)\n    \"\"\"\n    )\n\n    # Views should change, when the original data changes\n    # Tables should not change, when the original data changes\n    c.sql(\n        \"\"\"\n        CREATE VIEW\n            count_view\n        AS (\n            SELECT COUNT(*) AS c FROM new_table\n        )\n    \"\"\"\n    )\n    c.sql(\n        \"\"\"\n        CREATE TABLE\n            count_table\n        AS (\n            SELECT COUNT(*) AS c FROM new_table\n        )\n    \"\"\"\n    )\n\n    from_view = c.sql(\"SELECT c FROM count_view\")\n    from_table = c.sql(\"SELECT c FROM count_table\")\n\n    assert_eq(from_view, pd.DataFrame({\"c\": [700]}))\n    assert_eq(from_table, pd.DataFrame({\"c\": [700]}))\n\n    df.iloc[:10].to_csv(temporary_data_file, index=False)\n\n    from_view = c.sql(\"SELECT c FROM count_view\")\n    from_table = c.sql(\"SELECT c FROM count_table\")\n\n    assert_eq(from_view, pd.DataFrame({\"c\": [10]}))\n    assert_eq(from_table, pd.DataFrame({\"c\": [700]}))\n\n\ndef test_replace_and_error(c, temporary_data_file, df):\n    c.sql(\n        \"\"\"\n        CREATE TABLE\n            new_table\n        AS (\n            SELECT 1 AS a\n        )\n    \"\"\"\n    )\n\n    assert_eq(\n        c.sql(\"SELECT a FROM new_table\"),\n        pd.DataFrame({\"a\": [1]}),\n        check_dtype=False,\n    )\n\n    with pytest.raises(RuntimeError):\n        c.sql(\n            \"\"\"\n            CREATE TABLE\n                new_table\n            AS (\n                SELECT 1\n            )\n        \"\"\"\n        )\n\n    c.sql(\n        \"\"\"\n        CREATE TABLE IF NOT EXISTS\n            new_table\n        AS (\n            SELECT 2 AS a\n        )\n    \"\"\"\n    )\n\n    assert_eq(\n        c.sql(\"SELECT a FROM new_table\"),\n        pd.DataFrame({\"a\": [1]}),\n        check_dtype=False,\n    )\n\n    c.sql(\n        \"\"\"\n        CREATE OR REPLACE TABLE\n            new_table\n        AS (\n            SELECT 2 AS a\n        )\n    \"\"\"\n    
)\n\n    assert_eq(\n        c.sql(\"SELECT a FROM new_table\"),\n        pd.DataFrame({\"a\": [2]}),\n        check_dtype=False,\n    )\n\n    c.sql(\"DROP TABLE new_table\")\n\n    with pytest.raises(dask_sql.utils.ParsingException):\n        c.sql(\"SELECT a FROM new_table\")\n\n    c.sql(\n        \"\"\"\n        CREATE TABLE IF NOT EXISTS\n            new_table\n        AS (\n            SELECT 3 AS a\n        )\n    \"\"\"\n    )\n\n    assert_eq(\n        c.sql(\"SELECT a FROM new_table\"),\n        pd.DataFrame({\"a\": [3]}),\n        check_dtype=False,\n    )\n\n    df.to_csv(temporary_data_file, index=False)\n    with pytest.raises(RuntimeError):\n        c.sql(\n            f\"\"\"\n            CREATE TABLE\n                new_table\n            WITH (\n                location = '{temporary_data_file}',\n                format = 'csv'\n            )\n        \"\"\"\n        )\n\n    c.sql(\n        f\"\"\"\n        CREATE TABLE IF NOT EXISTS\n            new_table\n        WITH (\n            location = '{temporary_data_file}',\n            format = 'csv'\n        )\n    \"\"\"\n    )\n\n    assert_eq(\n        c.sql(\"SELECT a FROM new_table\"),\n        pd.DataFrame({\"a\": [3]}),\n        check_dtype=False,\n    )\n\n    c.sql(\n        f\"\"\"\n        CREATE OR REPLACE TABLE\n            new_table\n        WITH (\n            location = '{temporary_data_file}',\n            format = 'csv'\n        )\n    \"\"\"\n    )\n\n    result_df = c.sql(\"SELECT * FROM new_table\")\n\n    assert_eq(result_df, df)\n\n\ndef test_drop(c):\n    with pytest.raises(RuntimeError):\n        c.sql(\"DROP TABLE new_table\")\n\n    c.sql(\"DROP TABLE IF EXISTS new_table\")\n\n    c.sql(\n        \"\"\"\n        CREATE TABLE\n            new_table\n        AS (\n            SELECT 1 AS a\n        )\n    \"\"\"\n    )\n\n    with pytest.raises(RuntimeError):\n        c.sql(\"DROP TABLE other.new_table\")\n\n    c.sql(\"DROP TABLE IF EXISTS new_table\")\n\n    with 
pytest.raises(dask_sql.utils.ParsingException):\n        c.sql(\"SELECT a FROM new_table\")\n\n\ndef test_create_gpu_error(c, df, temporary_data_file):\n    try:\n        import cudf\n    except ImportError:\n        cudf = None\n\n    if cudf is not None:\n        pytest.skip(\"GPU-related import errors only need to be checked on CPU\")\n\n    with pytest.raises(ModuleNotFoundError):\n        c.create_table(\"new_table\", df, gpu=True)\n\n    with pytest.raises(ModuleNotFoundError):\n        c.create_table(\"new_table\", dd.from_pandas(df, npartitions=2), gpu=True)\n\n    df.to_csv(temporary_data_file, index=False)\n\n    with pytest.raises(ModuleNotFoundError):\n        c.sql(\n            f\"\"\"\n            CREATE TABLE\n                new_table\n            WITH (\n                location = '{temporary_data_file}',\n                format = 'csv',\n                gpu = True\n            )\n        \"\"\"\n        )\n"
  },
  {
    "path": "tests/integration/test_distributeby.py",
    "content": "import dask.dataframe as dd\nimport pandas as pd\nimport pytest\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_distribute_by(c, gpu):\n    df = pd.DataFrame({\"id\": [0, 1, 2, 1, 2, 3], \"val\": [0, 1, 2, 1, 2, 3]})\n    ddf = dd.from_pandas(df, npartitions=2)\n\n    c.create_table(\"test\", ddf, gpu=gpu)\n    partitioned_ddf = c.sql(\n        \"\"\"\n    SELECT\n    id\n    FROM test\n    DISTRIBUTE BY id\n    \"\"\"\n    )\n    part_0_ids = partitioned_ddf.get_partition(0).compute().id.unique()\n    part_1_ids = partitioned_ddf.get_partition(1).compute().id.unique()\n\n    if gpu:\n        part_0_ids = part_0_ids.to_pandas()\n        part_1_ids = part_1_ids.to_pandas()\n\n    assert bool(set(part_0_ids) & set(part_1_ids)) is False\n"
  },
  {
    "path": "tests/integration/test_explain.py",
    "content": "import dask.dataframe as dd\nimport pandas as pd\nimport pytest\n\nfrom dask_sql import Statistics\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_sql_query_explain(c, gpu):\n    df = dd.from_pandas(pd.DataFrame({\"a\": [1, 2, 3]}), npartitions=1)\n    c.create_table(\"df\", df, gpu=gpu)\n\n    sql_string = c.sql(\"EXPLAIN SELECT * FROM df\")\n\n    assert sql_string.startswith(\"Projection: df.a\\n\")\n\n    sql_string = c.sql(\n        \"EXPLAIN SELECT MIN(a) AS a_min FROM other_df GROUP BY a\",\n        dataframes={\"other_df\": df},\n        gpu=gpu,\n    )\n    assert sql_string.startswith(\"Projection: MIN(other_df.a) AS a_min\\n\")\n    assert \"Aggregate: groupBy=[[other_df.a]], aggr=[[MIN(other_df.a)]]\" in sql_string\n\n\n@pytest.mark.xfail(reason=\"Need to add statistics to Rust optimizer\")\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_statistics_explain(c, gpu):\n    df = dd.from_pandas(pd.DataFrame({\"a\": [1, 2, 3]}), npartitions=1)\n    c.create_table(\"df\", df, statistics=Statistics(row_count=1337), gpu=gpu)\n\n    sql_string = c.explain(\"SELECT * FROM df\")\n\n    assert sql_string.startswith(\n        \"DaskTableScan(table=[[root, df]]): rowcount = 1337.0, cumulative cost = {1337.0 rows, 1338.0 cpu, 0.0 io}, id = \"\n    )\n"
  },
  {
    "path": "tests/integration/test_filter.py",
    "content": "import dask\nimport dask.dataframe as dd\nimport pandas as pd\nimport pytest\nfrom dask.utils_test import hlg_layer\nfrom packaging.version import parse as parseVersion\n\nfrom tests.utils import assert_eq, skipif_dask_expr_enabled\n\nDASK_GT_2022_4_2 = parseVersion(dask.__version__) >= parseVersion(\"2022.4.2\")\n\n\ndef test_filter(c, df):\n    return_df = c.sql(\"SELECT * FROM df WHERE a < 2\")\n\n    expected_df = df[df[\"a\"] < 2]\n    assert_eq(return_df, expected_df)\n\n\ndef test_filter_scalar(c, df):\n    return_df = c.sql(\"SELECT * FROM df WHERE True\")\n\n    expected_df = df\n    assert_eq(return_df, expected_df)\n\n    return_df = c.sql(\"SELECT * FROM df WHERE False\")\n\n    expected_df = df.head(0)\n    assert_eq(return_df, expected_df, check_index_type=False)\n\n    return_df = c.sql(\"SELECT * FROM df WHERE (1 = 1)\")\n\n    expected_df = df\n    assert_eq(return_df, expected_df)\n\n    return_df = c.sql(\"SELECT * FROM df WHERE (1 = 0)\")\n\n    expected_df = df.head(0)\n    assert_eq(return_df, expected_df, check_index_type=False)\n\n\ndef test_filter_complicated(c, df):\n    return_df = c.sql(\"SELECT * FROM df WHERE a < 3 AND (b > 1 AND b < 3)\")\n\n    expected_df = df[((df[\"a\"] < 3) & ((df[\"b\"] > 1) & (df[\"b\"] < 3)))]\n    assert_eq(\n        return_df,\n        expected_df,\n    )\n\n\ndef test_filter_with_nan(c):\n    return_df = c.sql(\"SELECT * FROM user_table_nan WHERE c = 3\")\n    expected_df = pd.DataFrame({\"c\": [3]}, dtype=\"Int8\")\n\n    assert_eq(\n        return_df,\n        expected_df,\n    )\n\n\ndef test_string_filter(c, string_table):\n    return_df = c.sql(\"SELECT * FROM string_table WHERE a = 'a normal string'\")\n\n    assert_eq(\n        return_df,\n        string_table.head(1),\n    )\n    # Condition needs to specifically check on `M` since this the literal `M`\n    # was getting parsed as a datetime dtype\n    return_df = c.sql(\"SELECT * from string_table WHERE a = 'M'\")\n    expected_df = 
string_table[string_table[\"a\"] == \"M\"]\n    assert_eq(return_df, expected_df)\n\n\n@pytest.mark.parametrize(\n    \"input_table\",\n    [\n        \"datetime_table\",\n        pytest.param(\n            \"gpu_datetime_table\",\n            marks=(pytest.mark.gpu),\n        ),\n    ],\n)\ndef test_filter_cast_date(c, input_table, request):\n    datetime_table = request.getfixturevalue(input_table)\n    return_df = c.sql(\n        f\"\"\"\n        SELECT * FROM {input_table} WHERE\n            CAST(timezone AS DATE) > DATE '2014-08-01'\n        \"\"\"\n    )\n    expected_df = datetime_table[\n        datetime_table[\"timezone\"].dt.tz_localize(None).dt.floor(\"D\").astype(\"<M8[ns]\")\n        > pd.Timestamp(\"2014-08-01\")\n    ]\n    assert_eq(return_df, expected_df)\n\n\n@pytest.mark.parametrize(\n    \"input_table\",\n    [\n        \"datetime_table\",\n        pytest.param(\n            \"gpu_datetime_table\",\n            marks=(pytest.mark.gpu),\n        ),\n    ],\n)\n@pytest.mark.xfail(\n    reason=\"Need support for non-UTC timezoned literals, see https://github.com/dask-contrib/dask-sql/issues/1193\"\n)\ndef test_filter_cast_timestamp(c, input_table, request):\n    datetime_table = request.getfixturevalue(input_table)\n    return_df = c.sql(\n        f\"\"\"\n        SELECT * FROM {input_table} WHERE\n            CAST(timezone AS TIMESTAMP) >= TIMESTAMP '2014-08-01 23:00:00+00'\n        \"\"\"\n    )\n\n    expected_df = datetime_table[\n        datetime_table[\"timezone\"].astype(\"<M8[ns]\")\n        >= pd.Timestamp(\"2014-08-01 23:00:00\")\n    ]\n    assert_eq(return_df, expected_df)\n\n\ndef test_filter_year(c):\n    df = pd.DataFrame({\"year\": [2015, 2016], \"month\": [2, 3], \"day\": [4, 5]})\n    df[\"dt\"] = pd.to_datetime(df)\n\n    c.create_table(\"datetime_test\", df)\n\n    return_df = c.sql(\"select * from datetime_test where year(dt) < 2016\")\n    expected_df = df[df[\"year\"] < 2016]\n\n    assert_eq(expected_df, 
return_df)\n\n\n@pytest.mark.parametrize(\n    \"query,df_func,filters\",\n    [\n        (\n            \"SELECT * FROM parquet_ddf WHERE b < 10\",\n            lambda x: x[x[\"b\"] < 10],\n            [[(\"b\", \"<\", 10)]],\n        ),\n        (\n            \"SELECT * FROM parquet_ddf WHERE a < 3 AND (b > 1 AND b < 5)\",\n            lambda x: x[(x[\"a\"] < 3) & ((x[\"b\"] > 1) & (x[\"b\"] < 5))],\n            [[(\"a\", \"<\", 3), (\"b\", \">\", 1), (\"b\", \"<\", 5)]],\n        ),\n        (\n            \"SELECT * FROM parquet_ddf WHERE (b > 5 AND b < 10) OR a = 1\",\n            lambda x: x[((x[\"b\"] > 5) & (x[\"b\"] < 10)) | (x[\"a\"] == 1)],\n            [[(\"b\", \">\", 5), (\"b\", \"<\", 10)], [(\"a\", \"==\", 1)]],\n        ),\n        pytest.param(\n            \"SELECT * FROM parquet_ddf WHERE b IN (1, 6)\",\n            lambda x: x[(x[\"b\"] == 1) | (x[\"b\"] == 6)],\n            [[(\"b\", \"==\", 1)], [(\"b\", \"==\", 6)]],\n        ),\n        pytest.param(\n            \"SELECT * FROM parquet_ddf WHERE b IN (1, 3, 5, 6)\",\n            lambda x: x[x[\"b\"].isin([1, 3, 5, 6])],\n            [[(\"b\", \"in\", (1, 3, 5, 6))]],\n        ),\n        pytest.param(\n            \"SELECT * FROM parquet_ddf WHERE c IN ('A', 'B', 'C', 'D')\",\n            lambda x: x[x[\"c\"].isin([\"A\", \"B\", \"C\", \"D\"])],\n            [[(\"c\", \"in\", (\"A\", \"B\", \"C\", \"D\"))]],\n        ),\n        pytest.param(\n            \"SELECT * FROM parquet_ddf WHERE b NOT IN (1, 6)\",\n            lambda x: x[(x[\"b\"] != 1) & (x[\"b\"] != 6)],\n            [[(\"b\", \"!=\", 1), (\"b\", \"!=\", 6)]],\n        ),\n        pytest.param(\n            \"SELECT * FROM parquet_ddf WHERE b NOT IN (1, 3, 5, 6)\",\n            lambda x: x[~x[\"b\"].isin([1, 3, 5, 6])],\n            [[(\"b\", \"not in\", (1, 3, 5, 6))]],\n        ),\n        (\n            \"SELECT a FROM parquet_ddf WHERE (b > 5 AND b < 10) OR a = 1\",\n            lambda x: x[((x[\"b\"] > 5) & (x[\"b\"] < 
10)) | (x[\"a\"] == 1)][[\"a\"]],\n            [[(\"b\", \">\", 5), (\"b\", \"<\", 10)], [(\"a\", \"==\", 1)]],\n        ),\n        (\n            # Original filters NOT in disjunctive normal form\n            \"SELECT a FROM parquet_ddf WHERE (parquet_ddf.b > 3 AND parquet_ddf.b < 10 OR parquet_ddf.a = 1) AND (parquet_ddf.c = 'A')\",\n            lambda x: x[\n                ((x[\"b\"] > 3) & (x[\"b\"] < 10) | (x[\"a\"] == 1)) & (x[\"c\"] == \"A\")\n            ][[\"a\"]],\n            [\n                [(\"c\", \"==\", \"A\"), (\"b\", \">\", 3), (\"b\", \"<\", 10)],\n                [(\"a\", \"==\", 1), (\"c\", \"==\", \"A\")],\n            ],\n        ),\n        (\n            # The predicate-pushdown optimization will be skipped here,\n            # because datetime accessors are not supported. However,\n            # the query should still succeed.\n            \"SELECT * FROM parquet_ddf WHERE year(d) < 2015\",\n            lambda x: x[x[\"d\"].dt.year < 2015],\n            None,\n        ),\n    ],\n)\n@skipif_dask_expr_enabled()\ndef test_predicate_pushdown(c, parquet_ddf, query, df_func, filters):\n\n    # Check for predicate pushdown.\n    # We can use the `hlg_layer` utility to make sure the\n    # `filters` field has been populated in `creation_info`\n    return_df = c.sql(query)\n    expect_filters = filters\n    got_filters = hlg_layer(return_df.dask, \"read-parquet\").creation_info[\"kwargs\"][\n        \"filters\"\n    ]\n    if expect_filters:\n        got_filters = frozenset(frozenset(v) for v in got_filters)\n        expect_filters = frozenset(frozenset(v) for v in filters)\n\n    assert got_filters == expect_filters\n\n    # Check computed result is correct\n    df = parquet_ddf\n    expected_df = df_func(df)\n\n    # divisions aren't equal for older dask versions\n    assert_eq(\n        return_df, expected_df, check_index=False, check_divisions=DASK_GT_2022_4_2\n    )\n\n\ndef test_filtered_csv(tmpdir, c):\n    # Predicate pushdown is NOT 
supported for CSV data.\n    # This test just checks that the \"attempted\"\n    # predicate-pushdown logic does not lead to\n    # any unexpected errors\n\n    # Write simple csv dataset\n    df = pd.DataFrame(\n        {\n            \"a\": [1, 2, 3] * 5,\n            \"b\": range(15),\n            \"c\": [\"A\"] * 15,\n        },\n    )\n    dd.from_pandas(df, npartitions=3).to_csv(tmpdir + \"/*.csv\", index=False)\n\n    # Read back with dask and apply WHERE query\n    csv_ddf = dd.read_csv(tmpdir + \"/*.csv\")\n    try:\n        c.create_table(\"my_csv_table\", csv_ddf)\n        return_df = c.sql(\"SELECT * FROM my_csv_table WHERE b < 10\")\n    finally:\n        c.drop_table(\"my_csv_table\")\n\n    # Check computed result is correct\n    df = csv_ddf\n    expected_df = df[df[\"b\"] < 10]\n\n    assert_eq(return_df, expected_df)\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_filter_decimal(c, gpu):\n    df = pd.DataFrame(\n        {\n            \"a\": [304.5, 35.305, 9.043, 102.424, 53.34],\n            \"b\": [2.2, 82.4, 42, 76.9, 54.4],\n            \"c\": [1, 2, 2, 5, 9],\n        }\n    )\n    c.create_table(\"df\", df, gpu=gpu)\n\n    result_df = c.sql(\n        \"\"\"\n        SELECT\n            c\n        FROM\n            df\n        WHERE\n            CAST(a AS DECIMAL) < CAST(b AS DECIMAL)\n        \"\"\"\n    )\n\n    expected_df = df.loc[df.a < df.b][[\"c\"]]\n\n    assert_eq(result_df, expected_df)\n\n    result_df = c.sql(\n        \"\"\"\n        SELECT\n            CAST(b AS DECIMAL) as b\n        FROM\n            df\n        WHERE\n            CAST(a AS DECIMAL) < DECIMAL '100.2'\n        \"\"\"\n    )\n\n    # decimal precision doesn't match up with pandas floats\n    if gpu:\n        result_df[\"b\"] = result_df[\"b\"].astype(\"float64\")\n\n    expected_df = df.loc[df.a < 100.2][[\"b\"]]\n\n    assert_eq(result_df, expected_df, check_index=False)\n    
c.drop_table(\"df\")\n\n\n@skipif_dask_expr_enabled()\ndef test_predicate_pushdown_isna(tmpdir):\n    from dask_sql.context import Context\n\n    c = Context()\n\n    path = str(tmpdir)\n    dd.from_pandas(\n        pd.DataFrame(\n            {\n                \"a\": [1, 2, None] * 5,\n                \"b\": range(15),\n                \"index\": range(15),\n            }\n        ),\n        npartitions=3,\n    ).to_parquet(path + \"/df1\")\n    df1 = dd.read_parquet(path + \"/df1\", index=\"index\")\n    c.create_table(\"df1\", df1)\n\n    dd.from_pandas(\n        pd.DataFrame(\n            {\n                \"a\": [None, 2, 3] * 5,\n                \"b\": range(15),\n                \"index\": range(15),\n            },\n        ),\n        npartitions=3,\n    ).to_parquet(path + \"/df2\")\n    df2 = dd.read_parquet(path + \"/df2\", index=\"index\")\n    c.create_table(\"df2\", df2)\n\n    return_df = c.sql(\"SELECT df1.a FROM df1, df2 WHERE df1.a = df2.a\")\n\n    # Check for predicate pushdown\n    filters = [[(\"a\", \"is not\", None)]]\n    got_filters = hlg_layer(return_df.dask, \"read-parquet\").creation_info[\"kwargs\"][\n        \"filters\"\n    ]\n\n    got_filters = frozenset(frozenset(v) for v in got_filters)\n    expect_filters = frozenset(frozenset(v) for v in filters)\n\n    assert got_filters == expect_filters\n    assert all(return_df.compute() == 2)\n    assert len(return_df) == 25\n"
  },
  {
    "path": "tests/integration/test_fugue.py",
    "content": "import dask.dataframe as dd\nimport pandas as pd\nimport pytest\n\nfrom dask_sql import Context\nfrom tests.utils import assert_eq\n\nfugue_sql = pytest.importorskip(\"fugue_sql\")\n\nfrom dask_sql.integrations.fugue import fsql_dask  # noqa: E402\n\n\ndef test_fugue_workflow(client):\n    dag = fugue_sql.FugueSQLWorkflow()\n    df = dag.df([[0, \"hello\"], [1, \"world\"]], \"a:int64,b:str\")\n    dag(\"SELECT * FROM df WHERE a > 0 YIELD DATAFRAME AS result\")\n\n    result = dag.run(\"dask\")\n    return_df = result[\"result\"].as_pandas()\n    assert_eq(return_df, pd.DataFrame({\"a\": [1], \"b\": [\"world\"]}))\n\n    result = dag.run(client)\n    return_df = result[\"result\"].as_pandas()\n    assert_eq(return_df, pd.DataFrame({\"a\": [1], \"b\": [\"world\"]}))\n\n\ndef test_fugue_fsql(client):\n    pdf = pd.DataFrame([[0, \"hello\"], [1, \"world\"]], columns=[\"a\", \"b\"])\n    dag = fugue_sql.fsql(\n        \"SELECT * FROM df WHERE a > 0 YIELD DATAFRAME AS result\",\n        df=pdf,\n    )\n\n    result = dag.run(\"dask\")\n    return_df = result[\"result\"].as_pandas()\n    assert_eq(return_df, pd.DataFrame({\"a\": [1], \"b\": [\"world\"]}))\n\n    result = dag.run(client)\n    return_df = result[\"result\"].as_pandas()\n    assert_eq(return_df, pd.DataFrame({\"a\": [1], \"b\": [\"world\"]}))\n\n\n@pytest.mark.flaky(reruns=4, condition=\"sys.version_info < (3, 10)\")\ndef test_dask_fsql(client):\n    def assert_fsql(df: pd.DataFrame) -> None:\n        assert_eq(df, pd.DataFrame({\"a\": [1]}))\n\n    # the simplest case: the SQL does not use any input and does not generate output\n    fsql_dask(\n        \"\"\"\n    CREATE [[0],[1]] SCHEMA a:long\n    SELECT * WHERE a>0\n    OUTPUT USING assert_fsql\n    \"\"\"\n    )\n\n    # it can directly use the dataframes inside dask-sql Context\n    c = Context()\n    c.create_table(\n        \"df\", dd.from_pandas(pd.DataFrame([[0], [1]], columns=[\"a\"]), npartitions=2)\n    )\n\n    fsql_dask(\n      
  \"\"\"\n    SELECT * FROM df WHERE a>0\n    OUTPUT USING assert_fsql\n    \"\"\",\n        c,\n    )\n\n    # for dataframes with name, they can register back to the Context (register=True)\n    # the return of fsql is the dict of all dask dataframes with explicit names\n    result = fsql_dask(\n        \"\"\"\n    x=SELECT * FROM df WHERE a>0\n    OUTPUT USING assert_fsql\n    \"\"\",\n        c,\n        register=True,\n    )\n    assert isinstance(result[\"x\"], dd.DataFrame)\n    assert \"x\" in c.schema[c.schema_name].tables\n\n    # integration test with fugue transformer extension\n    c = Context()\n    c.create_table(\n        \"df1\",\n        dd.from_pandas(\n            pd.DataFrame([[0, 1], [1, 2]], columns=[\"a\", \"b\"]), npartitions=2\n        ),\n    )\n    c.create_table(\n        \"df2\",\n        dd.from_pandas(\n            pd.DataFrame([[1, 2], [3, 4], [-4, 5]], columns=[\"a\", \"b\"]), npartitions=2\n        ),\n    )\n\n    # schema: *\n    def cumsum(df: pd.DataFrame) -> pd.DataFrame:\n        return df.cumsum()\n\n    fsql_dask(\n        \"\"\"\n    data = SELECT * FROM df1 WHERE a>0 UNION ALL SELECT * FROM df2 WHERE a>0 PERSIST\n    result1 = TRANSFORM data PREPARTITION BY a PRESORT b USING cumsum\n    result2 = TRANSFORM data PREPARTITION BY b PRESORT a USING cumsum\n    PRINT result1, result2\n    \"\"\",\n        c,\n        register=True,\n    )\n    assert \"result1\" in c.schema[c.schema_name].tables\n    assert \"result2\" in c.schema[c.schema_name].tables\n"
  },
  {
    "path": "tests/integration/test_function.py",
    "content": "import itertools\nimport operator\nimport sys\n\nimport dask.dataframe as dd\nimport numpy as np\nimport pytest\n\nfrom dask_sql.utils import ParsingException\nfrom tests.utils import assert_eq\n\n\ndef test_custom_function(c, df):\n    def f(x):\n        return x**2\n\n    c.register_function(f, \"f\", [(\"x\", np.float64)], np.float64)\n\n    return_df = c.sql(\"SELECT F(a) AS a FROM df\")\n\n    assert_eq(return_df, df[[\"a\"]] ** 2)\n\n\ndef test_custom_function_row(c, df):\n    def f(row):\n        return row[\"x\"] ** 2\n\n    c.register_function(f, \"f\", [(\"x\", np.float64)], np.float64, row_udf=True)\n\n    return_df = c.sql(\"SELECT F(a) AS a FROM df\")\n\n    assert_eq(return_df, df[[\"a\"]] ** 2)\n\n\n@pytest.mark.parametrize(\"colnames\", list(itertools.combinations([\"a\", \"b\", \"c\"], 2)))\ndef test_custom_function_any_colnames(colnames, df_wide, c):\n    # a third column is needed\n\n    def f(row):\n        return row[\"x\"] + row[\"y\"]\n\n    colname_x, colname_y = colnames\n    c.register_function(\n        f, \"f\", [(\"x\", np.int64), (\"y\", np.int64)], np.int64, row_udf=True\n    )\n\n    return_df = c.sql(f\"SELECT F({colname_x},{colname_y}) FROM df_wide\")\n\n    expect = df_wide[colname_x] + df_wide[colname_y]\n    got = return_df.iloc[:, 0]\n\n    assert_eq(expect, got, check_names=False)\n\n\n@pytest.mark.parametrize(\n    \"retty\",\n    [np.float64, np.float32, np.int64, np.int32, np.int16, np.int8, np.bool_],\n)\ndef test_custom_function_row_return_types(c, df, retty):\n    def f(row):\n        return row[\"x\"] ** 2\n\n    c.register_function(f, \"f\", [(\"x\", np.float64)], retty, row_udf=True)\n\n    return_df = c.sql(\"SELECT F(a) AS a FROM df\")\n\n    assert_eq(return_df, (df[[\"a\"]] ** 2).astype(retty))\n\n\n# Test row UDFs with one arg\n@pytest.mark.parametrize(\"k\", [1, 1.5, True])\n@pytest.mark.parametrize(\n    \"op\", [operator.add, operator.sub, operator.mul, 
operator.truediv]\n)\n@pytest.mark.parametrize(\"retty\", [np.int64, np.float64, np.bool_])\ndef test_custom_function_row_args(c, df, k, op, retty):\n    const_type = np.dtype(type(k)).type\n\n    if sys.platform == \"win32\" and const_type == np.int32:\n        const_type = np.int64\n\n    def f(row, k):\n        return op(row[\"a\"], k)\n\n    c.register_function(\n        f, \"f\", [(\"a\", np.float64), (\"k\", const_type)], retty, row_udf=True\n    )\n\n    return_df = c.sql(f\"SELECT F(a, {k}) as a from df\")\n    expected_df = op(df[[\"a\"]], k).astype(retty)\n\n    assert_eq(return_df, expected_df)\n\n\n# Test row UDFs with two args\n@pytest.mark.parametrize(\"k2\", [1, 1.5, True])\n@pytest.mark.parametrize(\"k1\", [1, 1.5, True])\n@pytest.mark.parametrize(\n    \"op\", [operator.add, operator.sub, operator.mul, operator.truediv]\n)\n@pytest.mark.parametrize(\"retty\", [np.int64, np.float64, np.bool_])\ndef test_custom_function_row_two_args(c, df, k1, k2, op, retty):\n    const_type_k1 = np.dtype(type(k1)).type\n    const_type_k2 = np.dtype(type(k2)).type\n\n    if sys.platform == \"win32\":\n        if const_type_k1 == np.int32:\n            const_type_k1 = np.int64\n        if const_type_k2 == np.int32:\n            const_type_k2 = np.int64\n\n    def f(row, k1, k2):\n        x = op(row[\"a\"], k1)\n        y = op(x, k2)\n\n        return y\n\n    c.register_function(\n        f,\n        \"f\",\n        [(\"a\", np.float64), (\"k1\", const_type_k1), (\"k2\", const_type_k2)],\n        retty,\n        row_udf=True,\n    )\n\n    return_df = c.sql(f\"SELECT F(a, {k1}, {k2}) as a from df\")\n    expected_df = op(op(df[[\"a\"]], k1), k2).astype(retty)\n\n    assert_eq(return_df, expected_df)\n\n\ndef test_multiple_definitions(c, df_simple):\n    def f(x):\n        return x**2\n\n    c.register_function(f, \"f\", [(\"x\", np.float64)], np.float64)\n    c.register_function(f, \"f\", [(\"x\", np.int64)], np.int64)\n\n    return_df = c.sql(\n        \"\"\"\n       
 SELECT F(a) AS a, f(b) AS b\n        FROM df_simple\n        \"\"\"\n    )\n    expected_df = df_simple[[\"a\", \"b\"]] ** 2\n\n    assert_eq(return_df, expected_df)\n\n    def f(x):\n        return x**3\n\n    c.register_function(f, \"f\", [(\"x\", np.float64)], np.float64, replace=True)\n    c.register_function(f, \"f\", [(\"x\", np.int64)], np.int64)\n\n    return_df = c.sql(\n        \"\"\"\n        SELECT F(a) AS a, f(b) AS b\n        FROM df_simple\n        \"\"\"\n    )\n    expected_df = df_simple[[\"a\", \"b\"]] ** 3\n\n    assert_eq(return_df, expected_df)\n\n\ndef test_aggregate_function(c):\n    fagg = dd.Aggregation(\"f\", lambda x: x.sum(), lambda x: x.sum())\n    c.register_aggregation(fagg, \"fagg\", [(\"x\", np.float64)], np.float64)\n\n    return_df = c.sql(\n        \"\"\"\n        SELECT FAGG(b) AS test, SUM(b) AS \"S\"\n        FROM df\n        \"\"\"\n    )\n\n    assert_eq(return_df[\"test\"], return_df[\"S\"], check_names=False)\n\n\ndef test_reregistration(c):\n    def f(x):\n        return x**2\n\n    # The same is fine\n    c.register_function(f, \"f\", [(\"x\", np.float64)], np.float64)\n    c.register_function(f, \"f\", [(\"x\", np.int64)], np.int64)\n\n    def f(x):\n        return x**3\n\n    # A different not\n    with pytest.raises(ValueError):\n        c.register_function(f, \"f\", [(\"x\", np.float64)], np.float64)\n\n    # only if we replace it\n    c.register_function(f, \"f\", [(\"x\", np.float64)], np.float64, replace=True)\n\n    fagg = dd.Aggregation(\"f\", lambda x: x.sum(), lambda x: x.sum())\n    c.register_aggregation(fagg, \"fagg\", [(\"x\", np.float64)], np.float64)\n    c.register_aggregation(fagg, \"fagg\", [(\"x\", np.int64)], np.int64)\n\n    fagg = dd.Aggregation(\"f\", lambda x: x.mean(), lambda x: x.mean())\n\n    with pytest.raises(ValueError):\n        c.register_aggregation(fagg, \"fagg\", [(\"x\", np.float64)], np.float64)\n\n    c.register_aggregation(fagg, \"fagg\", [(\"x\", np.float64)], np.float64, 
replace=True)\n\n\n@pytest.mark.parametrize(\"dtype\", [np.timedelta64, None, \"a string\"])\ndef test_unsupported_dtype(c, dtype):\n    def f(x):\n        return x**2\n\n    # test that an invalid return type raises\n    with pytest.raises(NotImplementedError):\n        c.register_function(f, \"f\", [(\"x\", np.int64)], dtype)\n\n    # test that an invalid param type raises\n    with pytest.raises(NotImplementedError):\n        c.register_function(f, \"f\", [(\"x\", dtype)], np.int64)\n\n\n# TODO: explore implicitly casting inputs to the expected types consistently\ndef test_wrong_input_type(c):\n    def f(a):\n        return a\n\n    c.register_function(f, \"f\", [(\"a\", np.int64)], np.int64)\n\n    with pytest.raises(ParsingException):\n        c.sql(\"SELECT F(CAST(a AS INT)) AS a FROM df\")\n"
  },
  {
    "path": "tests/integration/test_groupby.py",
    "content": "import dask.dataframe as dd\nimport numpy as np\nimport pandas as pd\nimport pytest\nfrom dask.datasets import timeseries\n\nfrom tests.utils import assert_eq\n\n\n@pytest.fixture()\ndef timeseries_df(c):\n    pdf = timeseries(freq=\"1d\").compute().reset_index(drop=True)\n\n    # input nans in pandas dataframe\n    col1_index = np.random.randint(0, 30, size=int(pdf.shape[0] * 0.2))\n    col2_index = np.random.randint(0, 30, size=int(pdf.shape[0] * 0.3))\n    pdf.loc[col1_index, \"x\"] = np.nan\n    pdf.loc[col2_index, \"y\"] = np.nan\n\n    c.create_table(\"timeseries\", pdf, persist=True)\n\n    return None\n\n\ndef test_group_by(c):\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        user_id, SUM(b) AS \"S\"\n    FROM user_table_1\n    GROUP BY user_id\n    \"\"\"\n    )\n    expected_df = pd.DataFrame({\"user_id\": [1, 2, 3], \"S\": [3, 4, 3]})\n\n    assert_eq(return_df.sort_values(\"user_id\").reset_index(drop=True), expected_df)\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_group_by_multi(c, gpu):\n    df = pd.DataFrame({\"a\": [1, 2, 3], \"b\": [1, 1, 2]})\n    c.create_table(\"df\", df, gpu=gpu)\n\n    result_df = c.sql(\n        \"\"\"\n        SELECT\n            SUM(a) AS s,\n            AVG(a) AS av,\n            COUNT(a) AS c\n        FROM\n            df\n        GROUP BY\n            b\n        \"\"\"\n    )\n\n    expected_df = pd.DataFrame(\n        {\n            \"s\": df.groupby(\"b\").sum()[\"a\"],\n            \"av\": df.groupby(\"b\").mean()[\"a\"],\n            \"c\": df.groupby(\"b\").count()[\"a\"],\n        }\n    )\n\n    assert_eq(result_df, expected_df, check_index=False)\n\n    c.drop_table(\"df\")\n\n\ndef test_group_by_all(c, df):\n    result_df = c.sql(\n        \"\"\"\n    SELECT\n        SUM(b) AS \"S\", SUM(2) AS \"X\"\n    FROM user_table_1\n    \"\"\"\n    )\n    expected_df = pd.DataFrame({\"S\": [10], \"X\": [8]})\n\n    assert_eq(result_df, 
expected_df)\n\n    result_df = c.sql(\n        \"\"\"\n        SELECT\n            SUM(a) AS sum_a,\n            AVG(a) AS avg_a,\n            SUM(b) AS sum_b,\n            AVG(b) AS avg_b,\n            SUM(a)+AVG(b) AS mix_1,\n            SUM(a+b) AS mix_2,\n            AVG(a+b) AS mix_3\n        FROM df\n        \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"sum_a\": [df.a.sum()],\n            \"avg_a\": [df.a.mean()],\n            \"sum_b\": [df.b.sum()],\n            \"avg_b\": [df.b.mean()],\n            \"mix_1\": [df.a.sum() + df.b.mean()],\n            \"mix_2\": [(df.a + df.b).sum()],\n            \"mix_3\": [(df.a + df.b).mean()],\n        }\n    )\n\n    assert_eq(result_df, expected_df)\n\n\ndef test_group_by_filtered(c):\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        SUM(b) FILTER (WHERE user_id = 2) AS \"S1\",\n        SUM(b) \"S2\"\n    FROM user_table_1\n    \"\"\"\n    )\n    expected_df = pd.DataFrame({\"S1\": [4], \"S2\": [10]}, dtype=\"int64\")\n\n    assert_eq(return_df, expected_df)\n\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        user_id,\n        SUM(b) FILTER (WHERE user_id = 2) AS \"S1\",\n        SUM(b) \"S2\"\n    FROM user_table_1\n    GROUP BY user_id\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"user_id\": [1, 2, 3],\n            \"S1\": [np.nan, 4.0, np.nan],\n            \"S2\": [3, 4, 3],\n        },\n    )\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        SUM(b) FILTER (WHERE user_id = 2) AS \"S1\"\n    FROM user_table_1\n    \"\"\"\n    )\n    expected_df = pd.DataFrame({\"S1\": [4]})\n    assert_eq(return_df, expected_df)\n\n\n@pytest.mark.xfail(reason=\"WIP DataFusion\")\ndef test_group_by_case(c):\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        user_id + 1 AS \"A\", SUM(CASE WHEN b = 3 THEN 1 END) AS \"S\"\n    FROM user_table_1\n    GROUP BY user_id + 1\n  
  \"\"\"\n    )\n    expected_df = pd.DataFrame({\"A\": [2, 3, 4], \"S\": [1, 1, 1]})\n\n    # Do not check dtypes, as pandas versions are inconsistent here\n    assert_eq(\n        return_df.sort_values(\"A\").reset_index(drop=True),\n        expected_df,\n        check_dtype=False,\n    )\n\n\ndef test_group_by_nan(c, user_table_nan):\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        c\n    FROM user_table_nan\n    GROUP BY c\n    \"\"\"\n    )\n    expected_df = user_table_nan.drop_duplicates(subset=[\"c\"])\n\n    # we return nullable int dtype instead of float\n    assert_eq(return_df, expected_df, check_dtype=False)\n\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        c\n    FROM user_table_inf\n    GROUP BY c\n    \"\"\"\n    )\n    expected_df = pd.DataFrame({\"c\": [3, 1, float(\"inf\")]})\n    expected_df[\"c\"] = expected_df[\"c\"].astype(\"float64\")\n\n    assert_eq(\n        return_df.sort_values(\"c\").reset_index(drop=True),\n        expected_df.sort_values(\"c\").reset_index(drop=True),\n    )\n\n\ndef test_aggregations(c):\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        user_id,\n        EVERY(b = 3) AS e,\n        BIT_AND(b) AS b,\n        BIT_OR(b) AS bb,\n        MIN(b) AS m,\n        SINGLE_VALUE(b) AS s,\n        AVG(b) AS a\n    FROM user_table_1\n    GROUP BY user_id\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"user_id\": [1, 2, 3],\n            \"e\": [True, False, True],\n            \"b\": [3, 1, 3],\n            \"bb\": [3, 3, 3],\n            \"m\": [3, 1, 3],\n            \"s\": [3, 3, 3],\n            \"a\": [3, 2, 3],\n        }\n    )\n    expected_df[\"a\"] = expected_df[\"a\"].astype(\"float64\")\n\n    assert_eq(return_df.sort_values(\"user_id\").reset_index(drop=True), expected_df)\n\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        user_id,\n        EVERY(c = 3) AS e,\n        BIT_AND(c) AS b,\n        BIT_OR(c) AS bb,\n        MIN(c) AS m,\n        
SINGLE_VALUE(c) AS s,\n        AVG(c) AS a\n    FROM user_table_2\n    GROUP BY user_id\n    \"\"\"\n    )\n\n    expected_df = pd.DataFrame(\n        {\n            \"user_id\": [1, 2, 4],\n            \"e\": [False, True, False],\n            \"b\": [0, 3, 4],\n            \"bb\": [3, 3, 4],\n            \"m\": [1, 3, 4],\n            \"s\": [1, 3, 4],\n            \"a\": [1.5, 3, 4],\n        }\n    )\n    assert_eq(return_df.sort_values(\"user_id\").reset_index(drop=True), expected_df)\n\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        MAX(a) AS \"max\",\n        MIN(a) AS \"min\"\n    FROM string_table\n    \"\"\"\n    )\n    expected_df = pd.DataFrame({\"max\": [\"a normal string\"], \"min\": [\"%_%\"]})\n\n    assert_eq(return_df.reset_index(drop=True), expected_df)\n\n\n@pytest.mark.parametrize(\n    \"gpu\",\n    [\n        False,\n        pytest.param(\n            True,\n            marks=(\n                pytest.mark.gpu,\n                pytest.mark.xfail(\n                    reason=\"stddev_pop is failing on GPU, see https://github.com/dask-contrib/dask-sql/issues/681\"\n                ),\n            ),\n        ),\n    ],\n)\ndef test_stddev(c, gpu):\n    df = pd.DataFrame(\n        {\n            \"a\": [1, 1, 2, 1, 2],\n            \"b\": [4, 6, 3, 8, 5],\n        }\n    )\n\n    c.create_table(\"df\", df, gpu=gpu)\n\n    return_df = c.sql(\n        \"\"\"\n        SELECT\n            STDDEV(b) AS s\n        FROM df\n        GROUP BY df.a\n        \"\"\"\n    )\n\n    expected_df = pd.DataFrame({\"s\": df.groupby(\"a\").std()[\"b\"]})\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n    return_df = c.sql(\n        \"\"\"\n        SELECT\n            STDDEV_SAMP(b) AS ss\n        FROM df\n        \"\"\"\n    )\n\n    expected_df = pd.DataFrame({\"ss\": [df.std()[\"b\"]]})\n\n    assert_eq(return_df, expected_df.reset_index(drop=True))\n\n    return_df = c.sql(\n        \"\"\"\n        SELECT\n            STDDEV_POP(b) 
AS sp\n        FROM df\n        GROUP BY df.a\n        \"\"\"\n    )\n\n    expected_df = pd.DataFrame({\"sp\": df.groupby(\"a\").std(ddof=0)[\"b\"]})\n\n    assert_eq(return_df, expected_df.reset_index(drop=True))\n\n    return_df = c.sql(\n        \"\"\"\n        SELECT\n            STDDEV(a) as s,\n            STDDEV_SAMP(a) ss,\n            STDDEV_POP(b) sp\n        FROM\n            df\n        \"\"\"\n    )\n\n    expected_df = pd.DataFrame(\n        {\n            \"s\": [df.std()[\"a\"]],\n            \"ss\": [df.std()[\"a\"]],\n            \"sp\": [df.std(ddof=0)[\"b\"]],\n        }\n    )\n\n    assert_eq(return_df, expected_df.reset_index(drop=True))\n\n    c.drop_table(\"df\")\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_regr_aggregation(c, timeseries_df, gpu):\n    # test regr_count\n    regr_count = c.sql(\n        \"\"\"\n    SELECT\n        name,\n        COUNT(x) FILTER (WHERE y IS NOT NULL) AS expected,\n        REGR_COUNT(y, x) AS calculated\n    FROM timeseries\n    GROUP BY name\n    \"\"\"\n    ).fillna(0)\n\n    assert_eq(\n        regr_count[\"expected\"],\n        regr_count[\"calculated\"],\n        check_dtype=False,\n        check_names=False,\n    )\n\n    # test regr_syy\n    regr_syy = c.sql(\n        \"\"\"\n    SELECT\n        name,\n        (REGR_COUNT(y, x) * VAR_POP(y)) AS expected,\n        REGR_SYY(y, x) AS calculated\n    FROM timeseries\n    WHERE x IS NOT NULL AND y IS NOT NULL\n    GROUP BY name\n    \"\"\"\n    ).fillna(0)\n\n    assert_eq(\n        regr_syy[\"expected\"],\n        regr_syy[\"calculated\"],\n        check_dtype=False,\n        check_names=False,\n    )\n\n    # test regr_sxx\n    regr_sxx = c.sql(\n        \"\"\"\n    SELECT\n        name,\n        (REGR_COUNT(y, x) * VAR_POP(x)) AS expected,\n        REGR_SXX(y,x) AS calculated\n    FROM timeseries\n    WHERE x IS NOT NULL AND y IS NOT NULL\n    GROUP BY name\n    \"\"\"\n    ).fillna(0)\n\n    
assert_eq(\n        regr_sxx[\"expected\"],\n        regr_sxx[\"calculated\"],\n        check_dtype=False,\n        check_names=False,\n    )\n\n\n@pytest.mark.xfail(\n    reason=\"WIP DataFusion - https://github.com/dask-contrib/dask-sql/issues/753\"\n)\ndef test_covar_aggregation(c, timeseries_df):\n    # test covar_pop\n    covar_pop = c.sql(\n        \"\"\"\n    WITH temp_agg AS (\n        SELECT\n            name,\n            AVG(y) FILTER (WHERE x IS NOT NULL) as avg_y,\n            AVG(x) FILTER (WHERE x IS NOT NULL) as avg_x\n        FROM timeseries\n        GROUP BY name\n    ) SELECT\n        ts.name,\n        SUM((y - avg_y) * (x - avg_x)) / REGR_COUNT(y, x) AS expected,\n        COVAR_POP(y,x) AS calculated\n    FROM timeseries AS ts\n    JOIN temp_agg AS ta ON ts.name = ta.name\n    GROUP BY ts.name\n    \"\"\"\n    ).fillna(0)\n\n    assert_eq(\n        covar_pop[\"expected\"],\n        covar_pop[\"calculated\"],\n        check_dtype=False,\n        check_names=False,\n    )\n\n    # test covar_samp\n    covar_samp = c.sql(\n        \"\"\"\n    WITH temp_agg AS (\n        SELECT\n            name,\n            AVG(y) FILTER (WHERE x IS NOT NULL) as avg_y,\n            AVG(x) FILTER (WHERE x IS NOT NULL) as avg_x\n        FROM timeseries\n        GROUP BY name\n    ) SELECT\n        ts.name,\n        SUM((y - avg_y) * (x - avg_x)) / (REGR_COUNT(y, x) - 1) as expected,\n        COVAR_SAMP(y,x) AS calculated\n    FROM timeseries AS ts\n    JOIN temp_agg AS ta ON ts.name = ta.name\n    GROUP BY ts.name\n    \"\"\"\n    ).fillna(0)\n\n    assert_eq(\n        covar_samp[\"expected\"],\n        covar_samp[\"calculated\"],\n        check_dtype=False,\n        check_names=False,\n    )\n\n\n@pytest.mark.parametrize(\n    \"input_table\",\n    [\n        \"user_table_1\",\n        pytest.param(\"gpu_user_table_1\", marks=pytest.mark.gpu),\n    ],\n)\n@pytest.mark.parametrize(\"split_out\", [1, 2, 4])\ndef test_groupby_split_out(c, input_table, split_out, 
request):\n    user_table = request.getfixturevalue(input_table)\n\n    return_df = c.sql(\n        f\"\"\"\n        SELECT\n        user_id, SUM(b) AS \"S\"\n        FROM {input_table}\n        GROUP BY user_id\n        \"\"\",\n        config_options={\"sql.aggregate.split_out\": split_out} if split_out else {},\n    )\n    expected_df = (\n        user_table.groupby(by=\"user_id\")\n        .agg({\"b\": \"sum\"})\n        .reset_index(drop=False)\n        .rename(columns={\"b\": \"S\"})\n        .sort_values(\"user_id\")\n    )\n\n    assert return_df.npartitions == (split_out if split_out else 1)\n    assert_eq(return_df.sort_values(\"user_id\"), expected_df, check_index=False)\n\n    return_df = c.sql(\n        f\"\"\"\n        SELECT DISTINCT(user_id) FROM {input_table}\n        \"\"\",\n        config_options={\"sql.aggregate.split_out\": split_out},\n    )\n    expected_df = user_table[[\"user_id\"]].drop_duplicates()\n    assert return_df.npartitions == (split_out if split_out else 1)\n    assert_eq(return_df.sort_values(\"user_id\"), expected_df, check_index=False)\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_groupby_split_every(c, gpu):\n    input_ddf = dd.from_pandas(\n        pd.DataFrame({\"user_id\": [1, 2, 3, 4] * 16, \"b\": [5, 6, 7, 8] * 16}),\n        npartitions=16,\n    )  # Need an input with multiple partitions to demonstrate split_every\n\n    c.create_table(\"split_every_input\", input_ddf, gpu=gpu)\n\n    query_string = \"\"\"\n    SELECT\n        user_id, SUM(b) AS \"S\"\n    FROM split_every_input\n    GROUP BY user_id\n    \"\"\"\n    split_every_2_df = c.sql(\n        query_string,\n        config_options={\"sql.aggregate.split_every\": 2},\n    )\n    split_every_3_df = c.sql(\n        query_string,\n        config_options={\"sql.aggregate.split_every\": 3},\n    )\n    split_every_4_df = c.sql(\n        query_string,\n        config_options={\"sql.aggregate.split_every\": 4},\n    
)\n\n    expected_df = (\n        input_ddf.groupby(by=\"user_id\")\n        .agg({\"b\": \"sum\"})\n        .reset_index(drop=False)\n        .rename(columns={\"b\": \"S\"})\n        .sort_values(\"user_id\")\n    )\n    assert (\n        len(split_every_2_df.dask.keys())\n        >= len(split_every_3_df.dask.keys())\n        >= len(split_every_4_df.dask.keys())\n    )\n\n    assert_eq(split_every_2_df, expected_df, check_index=False)\n    assert_eq(split_every_3_df, expected_df, check_index=False)\n    assert_eq(split_every_4_df, expected_df, check_index=False)\n\n    query_string = \"\"\"\n    SELECT DISTINCT(user_id) FROM split_every_input\n    \"\"\"\n    split_every_2_df = c.sql(\n        query_string,\n        config_options={\"sql.aggregate.split_every\": 2},\n    )\n    split_every_3_df = c.sql(\n        query_string,\n        config_options={\"sql.aggregate.split_every\": 3},\n    )\n    split_every_4_df = c.sql(\n        query_string,\n        config_options={\"sql.aggregate.split_every\": 4},\n    )\n\n    expected_df = input_ddf[[\"user_id\"]].drop_duplicates()\n\n    assert (\n        len(split_every_2_df.dask.keys())\n        >= len(split_every_3_df.dask.keys())\n        >= len(split_every_4_df.dask.keys())\n    )\n    assert_eq(split_every_2_df, expected_df, check_index=False)\n    assert_eq(split_every_3_df, expected_df, check_index=False)\n    assert_eq(split_every_4_df, expected_df, check_index=False)\n\n    c.drop_table(\"split_every_input\")\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_agg_decimal(c, gpu):\n    df = pd.DataFrame(\n        {\n            \"a\": [1.23, 12.65, 134.64, -34.3, 945.19],\n            \"b\": [1, 1, 2, 2, 3],\n        }\n    )\n\n    c.create_table(\"df\", df, gpu=gpu)\n\n    result_df = c.sql(\n        \"\"\"\n        SELECT\n            SUM(CAST(a AS DECIMAL)) as s,\n            COUNT(CAST(a AS DECIMAL)) as c,\n            SUM(CAST(a+a AS DECIMAL)) as s2\n        
FROM\n            df\n        GROUP BY\n            b\n        \"\"\"\n    )\n    # decimal precision doesn't match up with pandas floats\n    if gpu:\n        result_df[\"s\"] = result_df[\"s\"].astype(\"float64\")\n        result_df[\"s2\"] = result_df[\"s2\"].astype(\"float64\")\n\n    expected_df = pd.DataFrame(\n        {\n            \"s\": df.groupby(\"b\").sum()[\"a\"],\n            \"c\": df.groupby(\"b\").count()[\"a\"],\n            \"s2\": df.groupby(\"b\").sum()[\"a\"] + df.groupby(\"b\").sum()[\"a\"],\n        }\n    )\n\n    # dtype of count aggregation is float on gpu\n    assert_eq(result_df, expected_df, check_index=False, check_dtype=(not gpu))\n\n    result_df = c.sql(\n        \"\"\"\n        SELECT\n            MIN(CAST(a AS DECIMAL)) as min,\n            MAX(CAST(a AS DECIMAL)) as max\n        FROM\n            df\n        \"\"\"\n    )\n    # decimal precision doesn't match up with pandas floats\n    if gpu:\n        result_df[\"min\"] = result_df[\"min\"].astype(\"float64\")\n        result_df[\"max\"] = result_df[\"max\"].astype(\"float64\")\n\n    expected_df = pd.DataFrame(\n        {\n            \"min\": [df.a.min()],\n            \"max\": [df.a.max()],\n        }\n    )\n\n    assert_eq(result_df, expected_df)\n    c.drop_table(\"df\")\n"
  },
  {
    "path": "tests/integration/test_hive.py",
    "content": "import shutil\nimport sys\nimport tempfile\nimport time\n\nimport pandas as pd\nimport pytest\n\nfrom dask_sql.context import Context\nfrom tests.utils import assert_eq\n\npytestmark = pytest.mark.xfail(\n    condition=sys.platform in (\"win32\", \"darwin\"),\n    reason=\"hive testing not supported on Windows/macOS\",\n)\ndocker = pytest.importorskip(\"docker\")\nsqlalchemy = pytest.importorskip(\"sqlalchemy\")\npytest.importorskip(\"pyhive\")\n\n\nDEFAULT_CONFIG = {\n    \"HIVE_SITE_CONF_javax_jdo_option_ConnectionURL\": \"jdbc:postgresql://hive-metastore-postgresql/metastore\",\n    \"HIVE_SITE_CONF_javax_jdo_option_ConnectionDriverName\": \"org.postgresql.Driver\",\n    \"HIVE_SITE_CONF_javax_jdo_option_ConnectionUserName\": \"hive\",\n    \"HIVE_SITE_CONF_javax_jdo_option_ConnectionPassword\": \"hive\",\n    \"HIVE_SITE_CONF_datanucleus_autoCreateSchema\": \"false\",\n    \"HIVE_SITE_CONF_hive_metastore_uris\": \"thrift://hive-metastore:9083\",\n    \"HDFS_CONF_dfs_namenode_datanode_registration_ip___hostname___check\": \"false\",\n    \"CORE_CONF_fs_defaultFS\": \"file:///database\",\n    \"CORE_CONF_hadoop_http_staticuser_user\": \"root\",\n    \"CORE_CONF_hadoop_proxyuser_hue_hosts\": \"*\",\n    \"CORE_CONF_hadoop_proxyuser_hue_groups\": \"*\",\n    \"HIVE_SITE_CONF_fs_default_name\": \"file:///database\",\n    \"CORE_CONF_fs_defaultFS\": \"file:///database\",\n    \"HIVE_SIZE_CONF_hive_metastore_warehouse_dir\": \"file:///database\",\n}\n\n\n@pytest.fixture(scope=\"session\")\ndef hive_cursor():\n    \"\"\"\n    Getting a hive setup up and running is a bit more complicated.\n    We need three running docker containers:\n    * a postgres database to store the metadata\n    * the metadata server itself\n    * and a server to answer SQL queries\n\n    They are all started one after the other to check,\n    if they are up and running.\n    We \"fake\" a network filesystem (instead of using hdfs),\n    by mounting a temporary folder from the 
host to the\n    docker container, which can be accessed both by hive\n    and the dask-sql client.\n\n    We just need to make sure, to remove all containers,\n    the network and the temporary folders correctly again.\n\n    The ideas for the docker setup are taken from the docker-compose\n    hive setup described by bde2020.\n    \"\"\"\n    client = docker.from_env()\n\n    network = None\n    hive_server = None\n    hive_metastore = None\n    hive_postgres = None\n\n    tmpdir = tempfile.mkdtemp()\n    tmpdir_parted = tempfile.mkdtemp()\n    tmpdir_multiparted = tempfile.mkdtemp()\n\n    try:\n        network = client.networks.create(\"dask-sql-hive\", driver=\"bridge\")\n\n        hive_server = client.containers.create(\n            \"bde2020/hive:2.3.2-postgresql-metastore\",\n            hostname=\"hive-server\",\n            name=\"hive-server\",\n            network=\"dask-sql-hive\",\n            volumes=[\n                f\"{tmpdir}:{tmpdir}\",\n                f\"{tmpdir_parted}:{tmpdir_parted}\",\n                f\"{tmpdir_multiparted}:{tmpdir_multiparted}\",\n            ],\n            environment={\n                \"HIVE_CORE_CONF_javax_jdo_option_ConnectionURL\": \"jdbc:postgresql://hive-metastore-postgresql/metastore\",\n                **DEFAULT_CONFIG,\n            },\n        )\n\n        hive_metastore = client.containers.create(\n            \"bde2020/hive:2.3.2-postgresql-metastore\",\n            hostname=\"hive-metastore\",\n            name=\"hive-metastore\",\n            network=\"dask-sql-hive\",\n            environment=DEFAULT_CONFIG,\n            command=\"/opt/hive/bin/hive --service metastore\",\n        )\n\n        hive_postgres = client.containers.create(\n            \"bde2020/hive-metastore-postgresql:2.3.0\",\n            hostname=\"hive-metastore-postgresql\",\n            name=\"hive-metastore-postgresql\",\n            network=\"dask-sql-hive\",\n        )\n\n        # Wait for it to start\n        
hive_postgres.start()\n        hive_postgres.exec_run([\"bash\"])\n        for l in hive_postgres.logs(stream=True):\n            if b\"ready for start up.\" in l:\n                break\n\n        hive_metastore.start()\n        hive_metastore.exec_run([\"bash\"])\n        for l in hive_metastore.logs(stream=True):\n            if b\"Starting hive metastore\" in l:\n                break\n\n        hive_server.start()\n        hive_server.exec_run([\"bash\"])\n        for l in hive_server.logs(stream=True):\n            if b\"Starting HiveServer2\" in l:\n                break\n\n        # The server needs some time to start.\n        # It is easier to check for the first access\n        # on the metastore than to wait some\n        # arbitrary time.\n        for l in hive_metastore.logs(stream=True):\n            if b\"get_multi_table\" in l:\n                break\n\n        time.sleep(2)\n\n        hive_server.reload()\n        address = hive_server.attrs[\"NetworkSettings\"][\"Networks\"][\"dask-sql-hive\"][\n            \"IPAddress\"\n        ]\n        port = 10000\n        cursor = sqlalchemy.create_engine(f\"hive://{address}:{port}\").connect()\n\n        # Create a non-partitioned column\n        cursor.execute(\n            sqlalchemy.text(\n                f\"CREATE TABLE df (i INTEGER, j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir}'\"\n            )\n        )\n        cursor.execute(sqlalchemy.text(\"INSERT INTO df (i, j) VALUES (1, 2)\"))\n        cursor.execute(sqlalchemy.text(\"INSERT INTO df (i, j) VALUES (2, 4)\"))\n\n        cursor.execute(\n            sqlalchemy.text(\n                f\"CREATE TABLE df_part (i INTEGER) PARTITIONED BY (j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_parted}'\"\n            )\n        )\n        cursor.execute(\n            sqlalchemy.text(\"INSERT INTO df_part PARTITION (j=2) (i) VALUES (1)\")\n        )\n        cursor.execute(\n            sqlalchemy.text(\"INSERT 
INTO df_part PARTITION (j=4) (i) VALUES (2)\")\n        )\n\n        cursor.execute(\n            sqlalchemy.text(\n                f\"\"\"\n            CREATE TABLE df_parts (i INTEGER) PARTITIONED BY (j INTEGER, k STRING)\n            ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_multiparted}'\n            \"\"\"\n            )\n        )\n        cursor.execute(\n            sqlalchemy.text(\n                \"INSERT INTO df_parts PARTITION (j=1, k='a') (i) VALUES (1)\"\n            )\n        )\n        cursor.execute(\n            sqlalchemy.text(\n                \"INSERT INTO df_parts PARTITION (j=2, k='b') (i) VALUES (2)\"\n            )\n        )\n\n        # The data files are created as root user by default. Change that:\n        hive_server.exec_run([\"chmod\", \"a+rwx\", \"-R\", tmpdir])\n        hive_server.exec_run([\"chmod\", \"a+rwx\", \"-R\", tmpdir_parted])\n        hive_server.exec_run([\"chmod\", \"a+rwx\", \"-R\", tmpdir_multiparted])\n\n        yield cursor\n    except docker.errors.ImageNotFound:\n        pytest.skip(\n            \"Hive testing requires 'bde2020/hive:2.3.2-postgresql-metastore' and \"\n            \"'bde2020/hive-metastore-postgresql:2.3.0' docker images\"\n        )\n    finally:\n        # Now clean up: remove the containers and the network and the folders\n        for container in [hive_server, hive_metastore, hive_postgres]:\n            if container is None:\n                continue\n\n            try:\n                container.kill()\n            except Exception:\n                pass\n\n            container.remove()\n\n        if network is not None:\n            network.remove()\n\n        shutil.rmtree(tmpdir)\n        shutil.rmtree(tmpdir_parted)\n        shutil.rmtree(tmpdir_multiparted)\n\n\ndef test_select(hive_cursor):\n    c = Context()\n    c.create_table(\"df\", hive_cursor)\n\n    result_df = c.sql(\"SELECT * FROM df\")\n    expected_df = pd.DataFrame({\"i\": [1, 2], \"j\": [2, 4]}).astype(\"int32\")\n\n    assert_eq(result_df, 
expected_df, check_index=False)\n\n\ndef test_select_partitions(hive_cursor):\n    c = Context()\n    c.create_table(\"df_part\", hive_cursor)\n\n    result_df = c.sql(\"SELECT * FROM df_part\")\n    expected_df = pd.DataFrame({\"i\": [1, 2], \"j\": [2, 4]}).astype(\"int32\")\n    expected_df[\"j\"] = expected_df[\"j\"].astype(\"int64\")\n\n    assert_eq(result_df, expected_df, check_index=False)\n\n\ndef test_select_multipartitions(hive_cursor):\n    c = Context()\n    c.create_table(\"df_parts\", hive_cursor)\n\n    result_df = c.sql(\"SELECT * FROM df_parts\")\n    expected_df = pd.DataFrame({\"i\": [1, 2], \"j\": [1, 2], \"k\": [\"a\", \"b\"]})\n    expected_df[\"i\"] = expected_df[\"i\"].astype(\"int32\")\n    expected_df[\"j\"] = expected_df[\"j\"].astype(\"int64\")\n    expected_df[\"k\"] = expected_df[\"k\"].astype(\"object\")\n\n    assert_eq(result_df, expected_df, check_index=False)\n"
  },
  {
    "path": "tests/integration/test_intake.py",
    "content": "import os\nimport shutil\nimport tempfile\n\nimport pandas as pd\nimport pytest\n\nfrom dask_sql.context import Context\nfrom tests.utils import assert_eq, skipif_dask_expr_enabled\n\n# intake doesn't yet have proper dask-expr support\npytestmark = skipif_dask_expr_enabled(\n    reason=\"Intake doesn't yet have proper dask-expr support\"\n)\n\n# skip the test if intake is not installed\nintake = pytest.importorskip(\"intake\")\n\n\n@pytest.fixture()\ndef intake_catalog_location():\n    tmpdir = tempfile.mkdtemp()\n\n    df = pd.DataFrame({\"a\": [1], \"b\": [1.5]})\n\n    csv_location = os.path.join(tmpdir, \"data.csv\")\n    df.to_csv(csv_location, index=False)\n\n    yaml_location = os.path.join(tmpdir, \"catalog.yaml\")\n    with open(yaml_location, \"w\") as f:\n        f.write(\n            \"\"\"sources:\n    intake_table:\n        args:\n            urlpath: \"{{ CATALOG_DIR }}/data.csv\"\n        description: \"Some Data\"\n        driver: intake.source.csv.CSVSource\n        \"\"\"\n        )\n\n    try:\n        yield yaml_location\n    finally:\n        shutil.rmtree(tmpdir)\n\n\ndef check_read_table(c):\n    result_df = c.sql(\"SELECT * FROM df\").reset_index(drop=True)\n    expected_df = pd.DataFrame({\"a\": [1], \"b\": [1.5]})\n\n    assert_eq(result_df, expected_df)\n\n\ndef test_intake_catalog(intake_catalog_location):\n    catalog = intake.open_catalog(intake_catalog_location)\n    c = Context()\n    c.create_table(\"df\", catalog, intake_table_name=\"intake_table\")\n\n    check_read_table(c)\n\n\ndef test_intake_location(intake_catalog_location):\n    c = Context()\n    c.create_table(\n        \"df\", intake_catalog_location, format=\"intake\", intake_table_name=\"intake_table\"\n    )\n\n    check_read_table(c)\n\n\ndef test_intake_sql(intake_catalog_location):\n    c = Context()\n    c.sql(\n        f\"\"\"\n        CREATE TABLE df WITH (\n         location = '{intake_catalog_location}',\n         format = 'intake',\n         
intake_table_name = 'intake_table'\n        )\n    \"\"\"\n    )\n\n    check_read_table(c)\n"
  },
  {
    "path": "tests/integration/test_jdbc.py",
    "content": "from time import sleep\n\nimport pandas as pd\nimport pytest\n\nfrom dask_sql import Context\nfrom dask_sql.server.app import _init_app, app\nfrom dask_sql.server.presto_jdbc import create_meta_data\nfrom tests.integration.fixtures import DISTRIBUTED_TESTS\n\n# needed for the testclient\npytest.importorskip(\"requests\")\n\nschema = \"a_schema\"\ntable = \"a_table\"\n\n\n@pytest.fixture(scope=\"module\")\ndef c():\n    c = Context()\n    c.create_schema(schema)\n    tables = pd.DataFrame(create_table_row(), index=[0])\n    tables = tables.astype({\"AN_INT\": \"int64\"})\n    c.create_table(table, tables, schema_name=schema)\n\n    yield c\n\n    c.drop_schema(schema)\n\n\n@pytest.fixture(scope=\"module\")\ndef app_client(c):\n    c.sql(\"SELECT 1 + 1\").compute()\n    _init_app(app, c)\n    # late import for the importskip\n    from fastapi.testclient import TestClient\n\n    yield TestClient(app)\n\n    # avoid closing the client if it's session-wide\n    if not DISTRIBUTED_TESTS:\n        app.client.close()\n\n\n@pytest.mark.xfail(reason=\"WIP DataFusion\")\ndef test_jdbc_has_schema(app_client, c):\n    create_meta_data(c)\n\n    check_data(app_client)\n\n    response = app_client.post(\n        \"/v1/statement\", data=\"SELECT * from system.jdbc.schemas\"\n    )\n    assert response.status_code == 200\n    result = get_result_or_error(app_client, response)\n\n    assert_result(result, 2, 3)\n    assert result[\"columns\"] == [\n        {\n            \"name\": \"TABLE_CATALOG\",\n            \"type\": \"varchar\",\n            \"typeSignature\": {\"rawType\": \"varchar\", \"arguments\": []},\n        },\n        {\n            \"name\": \"TABLE_SCHEM\",\n            \"type\": \"varchar\",\n            \"typeSignature\": {\"rawType\": \"varchar\", \"arguments\": []},\n        },\n    ]\n    assert result[\"data\"] == [\n        [\"\", \"root\"],\n        [\"\", \"a_schema\"],\n        [\"\", \"system_jdbc\"],\n    ]\n\n\ndef 
test_jdbc_has_table(app_client, c):\n    create_meta_data(c)\n    check_data(app_client)\n\n    response = app_client.post(\"/v1/statement\", data=\"SELECT * from system.jdbc.tables\")\n    assert response.status_code == 200\n    result = get_result_or_error(app_client, response)\n\n    assert_result(result, 10, 4)\n    assert result[\"data\"] == [\n        [\"\", \"a_schema\", \"a_table\", \"\", \"\", \"\", \"\", \"\", \"\", \"\"],\n        [\"\", \"system_jdbc\", \"schemas\", \"\", \"\", \"\", \"\", \"\", \"\", \"\"],\n        [\"\", \"system_jdbc\", \"tables\", \"\", \"\", \"\", \"\", \"\", \"\", \"\"],\n        [\"\", \"system_jdbc\", \"columns\", \"\", \"\", \"\", \"\", \"\", \"\", \"\"],\n    ]\n\n\n@pytest.mark.xfail(reason=\"WIP DataFusion\")\ndef test_jdbc_has_columns(app_client, c):\n    create_meta_data(c)\n    check_data(app_client)\n\n    response = app_client.post(\n        \"/v1/statement\",\n        data=f\"SELECT * from system.jdbc.columns where TABLE_NAME = '{table}'\",\n    )\n    assert response.status_code == 200\n    client_result = get_result_or_error(app_client, response)\n\n    # ordering of rows isn't consistent between fastapi versions\n    context_result = (\n        c.sql(\"SELECT * FROM system_jdbc.columns WHERE TABLE_NAME = 'a_table'\")\n        .compute()\n        .values.tolist()\n    )\n\n    assert_result(client_result, 24, 3)\n    assert client_result[\"data\"] == context_result\n\n\ndef assert_result(result, col_len, data_len):\n    assert \"columns\" in result\n    assert \"data\" in result\n    assert \"error\" not in result\n    assert len(result[\"columns\"]) == col_len\n    assert len(result[\"data\"]) == data_len\n\n\ndef create_table_row(a_str: str = \"any\", an_int: int = 1, a_float: float = 1.1):\n    return {\n        \"A_STR\": a_str,\n        \"AN_INT\": an_int,\n        \"A_FLOAT\": a_float,\n    }\n\n\ndef check_data(app_client):\n    response = app_client.post(\"/v1/statement\", data=f\"SELECT * from 
{schema}.{table}\")\n    assert response.status_code == 200\n    a_table = get_result_or_error(app_client, response)\n    assert \"columns\" in a_table\n    assert \"data\" in a_table\n    assert \"error\" not in a_table\n\n\ndef get_result_or_error(app_client, response):\n    result = response.json()\n\n    assert \"nextUri\" in result\n    assert \"error\" not in result\n\n    status_url = result[\"nextUri\"]\n    next_url = status_url\n\n    counter = 0\n    while True:\n        response = app_client.get(next_url)\n        assert response.status_code == 200\n\n        result = response.json()\n\n        if \"nextUri\" not in result:\n            break\n\n        next_url = result[\"nextUri\"]\n\n        counter += 1\n        assert counter <= 100\n\n        sleep(0.1)\n\n    return result\n"
  },
  {
    "path": "tests/integration/test_join.py",
    "content": "from contextlib import nullcontext\n\nimport dask.dataframe as dd\nimport numpy as np\nimport pandas as pd\nimport pytest\nfrom dask.utils_test import hlg_layer\n\nfrom dask_sql import Context\nfrom dask_sql.datacontainer import Statistics\nfrom tests.utils import assert_eq, skipif_dask_expr_enabled\n\n\ndef test_join(c):\n    return_df = c.sql(\n        \"\"\"\n    SELECT lhs.user_id, lhs.b, rhs.c\n    FROM user_table_1 AS lhs\n    JOIN user_table_2 AS rhs\n    ON lhs.user_id = rhs.user_id\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\"user_id\": [1, 1, 2, 2], \"b\": [3, 3, 1, 3], \"c\": [1, 2, 3, 3]}\n    )\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n\ndef test_join_inner(c):\n    return_df = c.sql(\n        \"\"\"\n    SELECT lhs.user_id, lhs.b, rhs.c\n    FROM user_table_1 AS lhs\n    INNER JOIN user_table_2 AS rhs\n    ON lhs.user_id = rhs.user_id\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\"user_id\": [1, 1, 2, 2], \"b\": [3, 3, 1, 3], \"c\": [1, 2, 3, 3]}\n    )\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n\ndef test_join_outer(c):\n    return_df = c.sql(\n        \"\"\"\n    SELECT lhs.user_id, lhs.b, rhs.c\n    FROM user_table_1 AS lhs\n    FULL JOIN user_table_2 AS rhs\n    ON lhs.user_id = rhs.user_id\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            # That is strange. 
Unfortunately, it seems dask fills in the\n            # missing rows with NaN, not with NA...\n            \"user_id\": [1, 1, 2, 2, 3, np.nan],\n            \"b\": [3, 3, 1, 3, 3, np.nan],\n            \"c\": [1, 2, 3, 3, np.nan, 4],\n        }\n    )\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n\ndef test_join_left(c):\n    return_df = c.sql(\n        \"\"\"\n    SELECT lhs.user_id, lhs.b, rhs.c\n    FROM user_table_1 AS lhs\n    LEFT JOIN user_table_2 AS rhs\n    ON lhs.user_id = rhs.user_id\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            # That is strange. Unfortunately, it seems dask fills in the\n            # missing rows with NaN, not with NA...\n            \"user_id\": [1, 1, 2, 2, 3],\n            \"b\": [3, 3, 1, 3, 3],\n            \"c\": [1, 2, 3, 3, np.nan],\n        }\n    )\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_join_left_anti(c, gpu):\n    df1 = pd.DataFrame({\"id\": [1, 1, 2, 4], \"a\": [\"a\", \"b\", \"c\", \"d\"]})\n    df2 = pd.DataFrame({\"id\": [2, 1, 2, 3], \"b\": [\"c\", \"c\", \"a\", \"c\"]})\n    c.create_table(\"df_1\", df1, gpu=gpu)\n    c.create_table(\"df_2\", df2, gpu=gpu)\n\n    return_df = c.sql(\n        \"\"\"\n    SELECT lhs.id, lhs.a\n    FROM df_1 AS lhs\n    LEFT ANTI JOIN df_2 AS rhs\n    ON lhs.id = rhs.id\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"id\": [4],\n            \"a\": [\"d\"],\n        }\n    )\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n\n@pytest.mark.gpu\ndef test_join_left_semi(c):\n    df1 = pd.DataFrame({\"id\": [1, 1, 2, 4], \"a\": [\"a\", \"b\", \"c\", \"d\"]})\n    df2 = pd.DataFrame({\"id\": [2, 1, 2, 3], \"b\": [\"c\", \"c\", \"a\", \"c\"]})\n    c.create_table(\"df_1\", df1, gpu=True)\n    c.create_table(\"df_2\", df2, gpu=True)\n\n    return_df = c.sql(\n        \"\"\"\n    SELECT 
lhs.id, lhs.a\n    FROM df_1 AS lhs\n    LEFT SEMI JOIN df_2 AS rhs\n    ON lhs.id = rhs.id\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"id\": [1, 1, 2],\n            \"a\": [\"a\", \"b\", \"c\"],\n        }\n    )\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n\ndef test_join_right(c):\n    return_df = c.sql(\n        \"\"\"\n    SELECT lhs.user_id, lhs.b, rhs.c\n    FROM user_table_1 AS lhs\n    RIGHT JOIN user_table_2 AS rhs\n    ON lhs.user_id = rhs.user_id\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            # That is strange. Unfortunately, it seems dask fills in the\n            # missing rows with NaN, not with NA...\n            \"user_id\": [1, 1, 2, 2, np.nan],\n            \"b\": [3, 3, 1, 3, np.nan],\n            \"c\": [1, 2, 3, 3, 4],\n        }\n    )\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n\ndef test_join_cross(c, user_table_1, department_table):\n    return_df = c.sql(\n        \"\"\"\n    SELECT user_id, b, department_name\n    FROM user_table_1, department_table\n    \"\"\"\n    )\n\n    user_table_1[\"key\"] = 1\n    department_table[\"key\"] = 1\n\n    expected_df = dd.merge(user_table_1, department_table, on=\"key\").drop(columns=\"key\")\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n\ndef test_join_complex(c):\n    return_df = c.sql(\n        \"\"\"\n    SELECT lhs.a, rhs.b\n    FROM df_simple AS lhs\n    JOIN df_simple AS rhs\n    ON lhs.a < rhs.b\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\"a\": [1, 1, 1, 2, 2, 3], \"b\": [1.1, 2.2, 3.3, 2.2, 3.3, 3.3]}\n    )\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n    return_df = c.sql(\n        \"\"\"\n    SELECT lhs.a, lhs.b, rhs.a, rhs.b\n    FROM df_simple AS lhs\n    JOIN df_simple AS rhs\n    ON lhs.a < rhs.b AND lhs.b < rhs.a\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"lhs.a\": [1, 1, 2],\n            \"lhs.b\": 
[1.1, 1.1, 2.2],\n            \"rhs.a\": [2, 3, 3],\n            \"rhs.b\": [2.2, 3.3, 3.3],\n        }\n    )\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n    return_df = c.sql(\n        \"\"\"\n    SELECT lhs.user_id, lhs.b, rhs.user_id, rhs.c\n    FROM user_table_1 AS lhs\n    JOIN user_table_2 AS rhs\n    ON rhs.user_id = lhs.user_id AND rhs.c - lhs.b >= 0\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\"lhs.user_id\": [2, 2], \"b\": [1, 3], \"rhs.user_id\": [2, 2], \"c\": [3, 3]}\n    )\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n\ndef test_join_literal(c):\n    return_df = c.sql(\n        \"\"\"\n    SELECT lhs.user_id, lhs.b, rhs.user_id, rhs.c\n    FROM user_table_1 AS lhs\n    JOIN user_table_2 AS rhs\n    ON True\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"lhs.user_id\": [2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],\n            \"b\": [1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n            \"rhs.user_id\": [1, 1, 2, 4, 1, 1, 2, 4, 1, 1, 2, 4, 1, 1, 2, 4],\n            \"c\": [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4],\n        }\n    )\n\n    assert_eq(return_df, expected_df, check_index=False)\n\n    return_df = c.sql(\n        \"\"\"\n    SELECT lhs.user_id, lhs.b, rhs.user_id, rhs.c\n    FROM user_table_1 AS lhs\n    JOIN user_table_2 AS rhs\n    ON False\n    \"\"\"\n    )\n    expected_df = pd.DataFrame({\"lhs.user_id\": [], \"b\": [], \"rhs.user_id\": [], \"c\": []})\n\n    assert_eq(return_df, expected_df, check_dtype=False, check_index=False)\n\n\ndef test_conditional_join(c):\n    df1 = pd.DataFrame({\"a\": [1, 2, 2, 5, 6], \"b\": [\"w\", \"x\", \"y\", None, \"z\"]})\n    df2 = pd.DataFrame({\"c\": [None, 3, 2, 5], \"d\": [\"h\", \"i\", \"j\", \"k\"]})\n\n    expected_df = pd.merge(df1, df2, how=\"inner\", left_on=[\"a\"], right_on=[\"c\"])\n    expected_df = expected_df[~pd.isnull(expected_df.b)]\n\n    c.create_table(\"df1\", df1)\n    
c.create_table(\"df2\", df2)\n\n    actual_df = c.sql(\n        \"\"\"\n    SELECT * FROM df1\n    INNER JOIN df2 ON\n    (\n        a = c\n        AND b IS NOT NULL\n    )\n    \"\"\"\n    )\n\n    assert_eq(actual_df, expected_df, check_index=False, check_dtype=False)\n\n\ndef test_join_on_unary_cond_only(c):\n    df1 = pd.DataFrame({\"a\": [1, 2, 2, 5, 6], \"b\": [\"w\", \"x\", \"y\", None, \"z\"]})\n    df2 = pd.DataFrame({\"c\": [None, 3, 2, 5], \"d\": [\"h\", \"i\", \"j\", \"k\"]})\n\n    c.create_table(\"df1\", df1)\n    c.create_table(\"df2\", df2)\n\n    df1 = df1.assign(common=1)\n    df2 = df2.assign(common=1)\n\n    expected_df = df1.merge(df2, on=\"common\").drop(columns=\"common\")\n    expected_df = expected_df[~pd.isnull(expected_df.b)]\n\n    actual_df = c.sql(\"SELECT * FROM df1 INNER JOIN df2 ON b IS NOT NULL\")\n\n    assert_eq(actual_df, expected_df, check_index=False, check_dtype=False)\n\n\ndef test_join_case_projection_subquery():\n    c = Context()\n\n    # Tables for query\n    demo = pd.DataFrame({\"demo_sku\": [], \"hd_dep_count\": []})\n    site_page = pd.DataFrame({\"site_page_sk\": [], \"site_char_count\": []})\n    sales = pd.DataFrame(\n        {\"sales_hdemo_sk\": [], \"sales_page_sk\": [], \"sold_time_sk\": []}\n    )\n    t_dim = pd.DataFrame({\"t_time_sk\": [], \"t_hour\": []})\n\n    c.create_table(\"demos\", demo, persist=False)\n    c.create_table(\"site_page\", site_page, persist=False)\n    c.create_table(\"sales\", sales, persist=False)\n    c.create_table(\"t_dim\", t_dim, persist=False)\n\n    c.sql(\n        \"\"\"\n    SELECT CASE WHEN pmc > 0.0 THEN CAST (amc AS DOUBLE) / CAST (pmc AS DOUBLE) ELSE -1.0 END AS am_pm_ratio\n    FROM\n    (\n        SELECT SUM(amc1) AS amc, SUM(pmc1) AS pmc\n        FROM\n        (\n            SELECT\n                CASE WHEN t_hour BETWEEN 7 AND 8 THEN COUNT(1) ELSE 0 END AS amc1,\n                CASE WHEN t_hour BETWEEN 19 AND 20 THEN COUNT(1) ELSE 0 END AS pmc1\n            FROM 
sales ws\n            JOIN demos hd ON (hd.demo_sku = ws.sales_hdemo_sk and hd.hd_dep_count = 5)\n            JOIN site_page sp ON (sp.site_page_sk = ws.sales_page_sk and sp.site_char_count BETWEEN 5000 AND 6000)\n            JOIN t_dim td ON (td.t_time_sk = ws.sold_time_sk and td.t_hour IN (7,8,19,20))\n            GROUP BY t_hour\n        ) cnt_am_pm\n    ) sum_am_pm\n    \"\"\"\n    ).compute()\n\n\ndef test_conditional_join_with_limit(c):\n    df = pd.DataFrame({\"a\": [1, 2, 3, 4], \"b\": [5, 6, 7, 8]})\n    ddf = dd.from_pandas(df, 5)\n\n    c.create_table(\"many_partitions\", ddf)\n\n    df = df.assign(common=1)\n    expected_df = df.merge(df, on=\"common\", suffixes=(\"\", \"0\")).drop(columns=\"common\")\n    expected_df = expected_df[expected_df[\"a\"] >= 2][:4]\n\n    # Columns are renamed to use their fully qualified names which is more accurate\n    expected_df = expected_df.rename(\n        columns={\"a\": \"df1.a\", \"b\": \"df1.b\", \"a0\": \"df2.a\", \"b0\": \"df2.b\"}\n    )\n\n    actual_df = c.sql(\n        \"\"\"\n    SELECT * FROM\n        many_partitions as df1, many_partitions as df2\n    WHERE\n        df1.\"a\" >= 2\n    LIMIT 4\n    \"\"\"\n    )\n\n    assert_eq(actual_df, expected_df, check_index=False)\n\n\n@pytest.mark.filterwarnings(\n    \"ignore:You are merging on int and float:UserWarning:dask.dataframe.multi\"\n)\ndef test_intersect(c):\n\n    # Join df_simple against itself\n    actual_df = c.sql(\n        \"\"\"\n    select count(*) from (\n    select * from df_simple\n    intersect\n    select * from df_simple\n    ) hot_item\n    limit 100\n    \"\"\"\n    )\n    assert actual_df[\"COUNT(*)\"].compute()[0] == 3\n\n    # Join df_simple against itself, and then that result against df_wide. 
Nothing should match, so the result should be 0\n    actual_df = c.sql(\n        \"\"\"\n    select count(*) from (\n    select a, b from df_simple\n    intersect\n    select a, b from df_simple\n    intersect\n    select a, b from df_wide\n    ) hot_item\n    limit 100\n    \"\"\"\n    )\n    assert len(actual_df[\"COUNT(*)\"]) == 0\n\n    actual_df = c.sql(\n        \"\"\"\n    select * from df_simple intersect select * from df_simple\n    \"\"\"\n    )\n    assert actual_df.shape[0].compute() == 3\n\n\ndef test_intersect_multi_col(c):\n    df1 = pd.DataFrame({\"a\": [1, 2, 3], \"b\": [4, 5, 6], \"c\": [7, 8, 9]})\n    df2 = pd.DataFrame({\"a\": [1, 1, 1], \"b\": [4, 5, 6], \"c\": [7, 7, 7]})\n\n    c.create_table(\"df1\", df1)\n    c.create_table(\"df2\", df2)\n\n    return_df = c.sql(\"select * from df1 intersect select * from df2\")\n    expected_df = pd.DataFrame(\n        {\n            \"df1.a\": [1],\n            \"df1.b\": [4],\n            \"df1.c\": [7],\n            \"df2.a\": [1],\n            \"df2.b\": [4],\n            \"df2.c\": [7],\n        }\n    )\n    assert_eq(return_df, expected_df, check_index=False)\n\n\n# TODO: remove this marker once fix for dask-expr#1018 is released\n# see: https://github.com/dask/dask-expr/issues/1018\n@skipif_dask_expr_enabled(\"Waiting for fix to dask-expr#1018\")\ndef test_join_alias_w_projection(c, parquet_ddf):\n    result_df = c.sql(\n        \"SELECT t2.c as c_y from parquet_ddf t1, parquet_ddf t2 WHERE t1.a=t2.a and t1.c='A'\"\n    )\n    expected_df = parquet_ddf.merge(parquet_ddf, on=[\"a\"], how=\"inner\")\n    expected_df = expected_df[expected_df[\"c_x\"] == \"A\"][[\"c_y\"]]\n    assert_eq(result_df, expected_df, check_index=False)\n\n\ndef test_filter_columns_post_join(c):\n    df = pd.DataFrame({\"a\": [1, 2, 3, 4, 5], \"c\": [1, None, 2, 2, 2]})\n    df2 = pd.DataFrame({\"b\": [1, 1, 2, 2, 3], \"c\": [2, 2, 2, 2, 2]})\n    c.create_table(\"df\", df)\n    c.create_table(\"df2\", df2)\n\n    query 
= \"SELECT SUM(df.a) as sum_a, df2.b FROM df INNER JOIN df2 ON df.c=df2.c GROUP BY df2.b\"\n\n    explain_string = c.explain(query)\n    assert (\"Projection: df.a, df2.b\" in explain_string) or (\n        \"Projection: df2.b, df.a\" in explain_string\n    )\n\n    result_df = c.sql(query)\n    expected_df = pd.DataFrame({\"sum_a\": [24, 24, 12], \"b\": [1, 2, 3]})\n    assert_eq(result_df, expected_df)\n\n\ndef test_join_reorder(c):\n    df = pd.DataFrame({\"a1\": [1, 2, 3, 4, 5] * 2, \"a2\": [1, 1, 2, 2, 2] * 2})\n    df2 = pd.DataFrame({\"b1\": [1, 1, 2, 2, 3] * 10000, \"b2\": [2, 2, 2, 2, 2] * 10000})\n    df3 = pd.DataFrame({\"c2\": [1, 1, 2, 2, 3], \"c3\": [2, 3, 4, 5, 6]})\n    c.create_table(\"a\", df, statistics=Statistics(10))\n    c.create_table(\"b\", df2, statistics=Statistics(50000))\n    c.create_table(\"c\", df3, statistics=Statistics(5))\n\n    # Basic join reorder test\n    query = \"\"\"\n        SELECT a1, b2, c3\n        FROM a, b, c\n        WHERE b1 < 3 AND c3 < 5 AND a1 = b1 AND b2 = c2\n    \"\"\"\n\n    explain_string = c.explain(query)\n\n    first_join = \"Inner Join: b.b2 = c.c2\"\n    second_join = \"Inner Join: b.b1 = a.a1\"\n    \"\"\"\n    LogicalPlan is expected to look something like:\n\n    Limit: skip=0, fetch=10\n    Projection: a.a1, b.b2, c.c3\n        Inner Join: b.b1 = a.a1\n        Projection: b.b1, b.b2, c.c3\n            Inner Join: b.b2 = c.c2\n            Projection: b.b1, b.b2\n                TableScan: b projection=[b1, b2], full_filters=[b.b1 < Int64(3), b.b2 IS NOT NULL, b.b1 IS NOT NULL]\n            Projection: c.c2, c.c3\n                TableScan: c projection=[c2, c3], full_filters=[c.c3 < Int64(5), c.c2 IS NOT NULL]\n        Projection: a.a1\n            TableScan: a projection=[a1], full_filters=[a.a1 < Int64(3), a.a1 IS NOT NULL]\n\n    So the a-b join is expected to appear earlier in the string than the b-c join\n    \"\"\"\n    assert first_join in explain_string and second_join in explain_string\n    
assert explain_string.index(second_join) < explain_string.index(first_join)\n\n    result_df = c.sql(query)\n    merged_df = df.merge(df2, left_on=\"a1\", right_on=\"b1\").merge(\n        df3, left_on=\"b2\", right_on=\"c2\"\n    )\n    expected_df = merged_df[(merged_df[\"b1\"] < 3) & (merged_df[\"c3\"] < 5)][\n        [\"a1\", \"b2\", \"c3\"]\n    ]\n\n    assert_eq(result_df, expected_df, check_index=False)\n\n    # By default, join reordering should NOT reorder unfiltered dimension tables\n    query = \"\"\"\n        SELECT a1, b2, c3\n        FROM a, b, c\n        WHERE a1 = b1 AND b2 = c2\n    \"\"\"\n\n    explain_string = c.explain(query)\n\n    first_join = \"Inner Join: b.b1 = a.a1\"\n    second_join = \"Inner Join: b.b2 = c.c2\"\n    assert first_join in explain_string and second_join in explain_string\n    assert explain_string.index(second_join) < explain_string.index(first_join)\n\n    result_df = c.sql(query)\n    expected_df = df.merge(df2, left_on=\"a1\", right_on=\"b1\").merge(\n        df3, left_on=\"b2\", right_on=\"c2\"\n    )[[\"a1\", \"b2\", \"c3\"]]\n\n    assert_eq(result_df, expected_df, check_index=False)\n\n\ndef check_broadcast_join(df, val, raises=False):\n    \"\"\"\n    Check that the broadcast join is correctly set in the Dask layer or expression graph\n\n    Parameters\n    ----------\n    df : DataFrame\n        The DataFrame to check\n    val : bool or float\n        The expected value of the broadcast join\n    raises : bool, optional\n        Whether the legacy Dask check should raise an error if the broadcast join is not set\n    \"\"\"\n    if dd._dask_expr_enabled():\n        from dask_expr._merge import Merge\n\n        merge_ops = [op for op in df.expr.find_operations(Merge)]\n        assert len(merge_ops) == 1\n        assert merge_ops[0].broadcast == val\n    else:\n        with pytest.raises(KeyError) if raises else nullcontext():\n            assert hlg_layer(df.dask, 
\"bcast-join\")\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_broadcast_join(c, client, gpu):\n    df1 = dd.from_pandas(\n        pd.DataFrame({\"user_id\": [1, 2, 3, 4], \"b\": [5, 6, 7, 8]}),\n        npartitions=2,\n    )\n    df2 = dd.from_pandas(\n        pd.DataFrame({\"user_id\": [1, 2, 3, 4] * 4, \"c\": [5, 6, 7, 8] * 4}),\n        npartitions=8,\n    )\n    c.create_table(\"df1\", df1, gpu=gpu)\n    c.create_table(\"df2\", df2, gpu=gpu)\n\n    query_string = \"\"\"\n    SELECT df1.user_id as user_id, b, c\n    FROM df1, df2\n    WHERE df1.user_id = df2.user_id\n    \"\"\"\n\n    expected_df = df1.merge(df2, on=\"user_id\", how=\"inner\")\n\n    res_df = c.sql(query_string, config_options={\"sql.join.broadcast\": True})\n    check_broadcast_join(res_df, True)\n    assert_eq(\n        res_df,\n        expected_df,\n        check_divisions=False,\n        check_index=False,\n        scheduler=\"distributed\",\n    )\n\n    res_df = c.sql(query_string, config_options={\"sql.join.broadcast\": 1.0})\n    check_broadcast_join(res_df, 1.0)\n    assert_eq(\n        res_df,\n        expected_df,\n        check_divisions=False,\n        check_index=False,\n        scheduler=\"distributed\",\n    )\n\n    res_df = c.sql(query_string, config_options={\"sql.join.broadcast\": 0.5})\n    check_broadcast_join(res_df, 0.5, raises=True)\n    assert_eq(res_df, expected_df, check_index=False, scheduler=\"distributed\")\n\n    res_df = c.sql(query_string, config_options={\"sql.join.broadcast\": False})\n    check_broadcast_join(res_df, False, raises=True)\n    assert_eq(res_df, expected_df, check_index=False, scheduler=\"distributed\")\n\n    res_df = c.sql(query_string, config_options={\"sql.join.broadcast\": None})\n    check_broadcast_join(res_df, None, raises=True)\n    assert_eq(res_df, expected_df, check_index=False, scheduler=\"distributed\")\n\n\n@pytest.mark.gpu\ndef test_null_key_join(c):\n    df1 = 
pd.DataFrame({\"a\": [None, None, None, None, None, 1]})\n    df2 = pd.DataFrame({\"b\": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]})\n    c.create_table(\"df1\", df1, gpu=True)\n    c.create_table(\"df2\", df2, gpu=True)\n\n    result_df = c.sql(\n        \"SELECT * FROM (select * from df1 limit 5) JOIN (select * from df2 limit 5) ON a=b\"\n    )\n    expected_df = pd.DataFrame({\"a\": [], \"b\": []})\n\n    assert_eq(result_df, expected_df)\n"
  },
  {
    "path": "tests/integration/test_model.py",
    "content": "import os\nimport pickle\nimport sys\n\nimport joblib\nimport pandas as pd\nimport pytest\nfrom packaging.version import parse as parseVersion\n\nfrom tests.utils import assert_eq\n\ntry:\n    import cuml\n    import dask_cudf\n    import xgboost\nexcept ImportError:\n    cuml = None\n    xgboost = None\n    dask_cudf = None\n\nsklearn = pytest.importorskip(\"sklearn\")\n\nSKLEARN_EQ_140 = parseVersion(sklearn.__version__) == parseVersion(\"1.4.0\")\n\n\ndef check_trained_model(c, model_name=\"my_model\", df_name=\"timeseries\"):\n    sql = f\"\"\"\n    SELECT * FROM PREDICT(\n        MODEL {model_name},\n        SELECT x, y FROM {df_name}\n    )\n    \"\"\"\n\n    tables_before = c.schema[\"root\"].tables.keys()\n    result_df = c.sql(sql).compute()\n\n    # assert that there are no additional tables in context from prediction\n    assert tables_before == c.schema[\"root\"].tables.keys()\n    assert \"target\" in result_df.columns\n    assert len(result_df[\"target\"]) > 0\n\n\n@pytest.mark.parametrize(\n    \"gpu_client\", [False, pytest.param(True, marks=pytest.mark.gpu)], indirect=True\n)\ndef test_training_and_prediction(c, gpu_client):\n    gpu = \"CUDA\" in str(gpu_client.cluster)\n    timeseries = \"gpu_timeseries\" if gpu else \"timeseries\"\n\n    # cuML does not have a GradientBoostingClassifier\n    if not gpu:\n        c.sql(\n            \"\"\"\n            CREATE MODEL my_model WITH (\n                model_class = 'GradientBoostingClassifier',\n                wrap_predict = True,\n                target_column = 'target'\n            ) AS (\n                SELECT x, y, x*y > 0 AS target\n                FROM timeseries\n                LIMIT 100\n            )\n        \"\"\"\n        )\n        check_trained_model(c)\n\n        c.sql(\n            f\"\"\"\n            CREATE OR REPLACE MODEL my_model WITH (\n                model_class = 'LogisticRegression',\n                wrap_predict = True,\n                wrap_fit = 
False,\n                target_column = 'target'\n            ) AS (\n                SELECT x, y, x*y > 0 AS target\n                FROM {timeseries}\n            )\n        \"\"\"\n        )\n        check_trained_model(c, df_name=timeseries)\n\n    c.sql(\n        f\"\"\"\n        CREATE OR REPLACE MODEL my_model WITH (\n            model_class = 'LinearRegression',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y AS target\n            FROM {timeseries}\n        )\n    \"\"\"\n    )\n    check_trained_model(c, df_name=timeseries)\n\n\n# TODO: investigate deadlocks on GPU\n@pytest.mark.xfail(\n    sys.platform in (\"darwin\", \"win32\"),\n    reason=\"Intermittent failures on macOS/Windows\",\n    strict=False,\n)\n@pytest.mark.parametrize(\n    \"gpu_client\",\n    [\n        False,\n        pytest.param(\n            True, marks=(pytest.mark.gpu, pytest.mark.skip(reason=\"Deadlocks on GPU\"))\n        ),\n    ],\n    indirect=True,\n)\ndef test_xgboost_training_prediction(c, gpu_client):\n    gpu = \"CUDA\" in str(gpu_client.cluster)\n    timeseries = \"gpu_timeseries\" if gpu else \"timeseries\"\n\n    # TODO: XGBClassifiers error on GPU\n    if not gpu:\n        c.sql(\n            \"\"\"\n        CREATE OR REPLACE MODEL my_model WITH (\n            model_class = 'DaskXGBClassifier',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0  AS target\n            FROM timeseries\n            LIMIT 100\n        )\n        \"\"\"\n        )\n        check_trained_model(c)\n\n        c.sql(\n            \"\"\"\n        CREATE OR REPLACE MODEL my_model WITH (\n            model_class = 'XGBClassifier',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0  AS target\n            FROM timeseries\n            LIMIT 100\n        )\n        \"\"\"\n        )\n        check_trained_model(c)\n\n    # For GPU tests, set tree_method = 'gpu_hist'\n    tree_method = 
\"gpu_hist\" if gpu else \"hist\"\n\n    c.sql(\n        f\"\"\"\n    CREATE OR REPLACE MODEL my_model WITH (\n        model_class = 'DaskXGBRegressor',\n        target_column = 'target',\n        tree_method = '{tree_method}'\n    ) AS (\n        SELECT x, y, x*y  AS target\n        FROM {timeseries}\n    )\n    \"\"\"\n    )\n    check_trained_model(c, df_name=timeseries)\n\n    c.sql(\n        f\"\"\"\n    CREATE OR REPLACE MODEL my_model WITH (\n        model_class = 'XGBRegressor',\n        wrap_predict = True,\n        target_column = 'target',\n        tree_method = '{tree_method}'\n    ) AS (\n        SELECT x, y, x*y  AS target\n        FROM {timeseries}\n    )\n    \"\"\"\n    )\n    check_trained_model(c, df_name=timeseries)\n\n\n@pytest.mark.parametrize(\n    \"gpu_client\", [False, pytest.param(True, marks=pytest.mark.gpu)], indirect=True\n)\ndef test_clustering_and_prediction(c, gpu_client):\n    gpu = \"CUDA\" in str(gpu_client.cluster)\n    timeseries = \"gpu_timeseries\" if gpu else \"timeseries\"\n\n    c.sql(\n        f\"\"\"\n        CREATE MODEL my_model WITH (\n            model_class = 'KMeans'\n        ) AS (\n            SELECT x, y\n            FROM {timeseries}\n            LIMIT 100\n        )\n    \"\"\"\n    )\n    check_trained_model(c, df_name=timeseries)\n\n\ndef test_create_model_with_prediction(c):\n    c.sql(\n        \"\"\"\n        CREATE MODEL my_model1 WITH (\n            model_class = 'GradientBoostingClassifier',\n            wrap_predict = True,\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    c.sql(\n        \"\"\"\n        CREATE MODEL my_model2 WITH (\n            model_class = 'GradientBoostingClassifier',\n            wrap_predict = True,\n            target_column = 'target'\n        ) AS (\n            SELECT * FROM PREDICT (\n                MODEL my_model1,\n                
SELECT x, y FROM timeseries LIMIT 100\n            )\n        )\n    \"\"\"\n    )\n\n    check_trained_model(c, \"my_model2\")\n\n\ndef test_iterative_and_prediction(c):\n    c.sql(\n        \"\"\"\n        CREATE MODEL my_model WITH (\n            model_class = 'SGDClassifier',\n            wrap_fit = True,\n            target_column = 'target',\n            fit_kwargs = ( classes = ARRAY [0, 1] )\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n    check_trained_model(c)\n\n\ndef test_show_models(c):\n    c.sql(\n        \"\"\"\n        CREATE MODEL my_model1 WITH (\n            model_class = 'GradientBoostingClassifier',\n            wrap_predict = True,\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    c.sql(\n        \"\"\"\n        CREATE MODEL my_model2 WITH (\n            model_class = 'KMeans'\n        ) AS (\n            SELECT x, y\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    c.sql(\n        \"\"\"\n        CREATE MODEL my_model3 WITH (\n            model_class = 'SGDClassifier',\n            wrap_fit = True,\n            target_column = 'target',\n            fit_kwargs = ( classes = ARRAY [0, 1] )\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    result = c.sql(\"SHOW MODELS\")\n    expected = pd.DataFrame([\"my_model1\", \"my_model2\", \"my_model3\"], columns=[\"Models\"])\n\n    assert_eq(result, expected)\n\n\ndef test_wrong_training_or_prediction(c):\n    with pytest.raises(KeyError):\n        c.sql(\n            \"\"\"\n            SELECT * FROM PREDICT(\n            MODEL my_model,\n            SELECT x, y FROM timeseries\n        )\n        \"\"\"\n        )\n\n    with 
pytest.raises(ValueError):\n        c.sql(\n            \"\"\"\n            CREATE MODEL my_model WITH (\n                target_column = 'target'\n            ) AS (\n                SELECT x, y, x*y > 0 AS target\n                FROM timeseries\n                LIMIT 100\n            )\n        \"\"\"\n        )\n\n    with pytest.raises(ImportError):\n        c.sql(\n            \"\"\"\n            CREATE MODEL my_model WITH (\n                model_class = 'that.is.not.a.python.class',\n                target_column = 'target'\n            ) AS (\n                SELECT x, y, x*y > 0 AS target\n                FROM timeseries\n                LIMIT 100\n            )\n        \"\"\"\n        )\n\n\ndef test_correct_argument_passing(c):\n    c.sql(\n        \"\"\"\n        CREATE MODEL my_model WITH (\n            model_class = 'mock.MagicMock',\n            target_column = 'target',\n            fit_kwargs = (\n                single_quoted_string = 'hello',\n                double_quoted_string = \"hi\",\n                integer = -300,\n                float = 23.45,\n                boolean = False,\n                array = ARRAY [ 1, 2 ],\n                dict = MAP [ 'a', 1 ],\n                set = MULTISET [ 1, 1, 2, 3 ]\n            )\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    mocked_model, columns = c.schema[c.schema_name].models[\"my_model\"]\n    assert list(columns) == [\"x\", \"y\"]\n\n    fit_function = mocked_model.fit\n\n    fit_function.assert_called_once()\n    call_kwargs = fit_function.call_args.kwargs\n    assert call_kwargs == dict(\n        single_quoted_string=\"hello\",\n        double_quoted_string=\"hi\",\n        integer=-300,\n        float=23.45,\n        boolean=False,\n        array=[1, 2],\n        dict={\"a\": 1},\n        set={1, 2, 3},\n    )\n\n\ndef test_replace_and_error(c):\n    c.sql(\n        \"\"\"\n        CREATE 
MODEL my_model WITH (\n            model_class = 'mock.MagicMock',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    first_mock, _ = c.schema[c.schema_name].models[\"my_model\"]\n\n    with pytest.raises(RuntimeError):\n        c.sql(\n            \"\"\"\n            CREATE MODEL my_model WITH (\n                model_class = 'mock.MagicMock',\n                target_column = 'target'\n            ) AS (\n                SELECT x, y, x*y > 0 AS target\n                FROM timeseries\n                LIMIT 100\n            )\n        \"\"\"\n        )\n\n    c.sql(\n        \"\"\"\n        CREATE MODEL IF NOT EXISTS my_model WITH (\n            model_class = 'mock.MagicMock',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    assert c.schema[c.schema_name].models[\"my_model\"][0] == first_mock\n\n    c.sql(\n        \"\"\"\n        CREATE OR REPLACE MODEL my_model WITH (\n            model_class = 'mock.MagicMock',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    assert c.schema[c.schema_name].models[\"my_model\"][0] != first_mock\n    second_mock, _ = c.schema[c.schema_name].models[\"my_model\"]\n\n    c.sql(\"DROP MODEL my_model\")\n\n    c.sql(\n        \"\"\"\n        CREATE MODEL IF NOT EXISTS my_model WITH (\n            model_class = 'mock.MagicMock',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    assert c.schema[c.schema_name].models[\"my_model\"][0] != second_mock\n\n\ndef test_drop_model(c):\n    with 
pytest.raises(RuntimeError):\n        c.sql(\"DROP MODEL my_model\")\n\n    c.sql(\"DROP MODEL IF EXISTS my_model\")\n\n    c.sql(\n        \"\"\"\n        CREATE MODEL IF NOT EXISTS my_model WITH (\n            model_class = 'mock.MagicMock',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    c.sql(\"DROP MODEL IF EXISTS my_model\")\n\n    assert \"my_model\" not in c.schema[c.schema_name].models\n\n\ndef test_describe_model(c):\n    c.sql(\n        \"\"\"\n        CREATE MODEL ex_describe_model WITH (\n            model_class = 'GradientBoostingClassifier',\n            wrap_predict = True,\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    model, training_columns = c.schema[c.schema_name].models[\"ex_describe_model\"]\n    expected_dict = model.get_params()\n    expected_dict[\"training_columns\"] = training_columns.tolist()\n    # hack for converting model class into string\n    expected_series = (\n        pd.DataFrame.from_dict(expected_dict, orient=\"index\", columns=[\"Params\"])[\n            \"Params\"\n        ]\n        .apply(lambda x: str(x))\n        .sort_index()\n    )\n    actual_series = c.sql(\"DESCRIBE MODEL ex_describe_model\")\n    actual_series = actual_series[\"Params\"].apply(\n        lambda x: str(x), meta=actual_series[\"Params\"]\n    )\n\n    assert_eq(expected_series, actual_series)\n\n    with pytest.raises(RuntimeError):\n        c.sql(\"DESCRIBE MODEL undefined_model\")\n\n\ndef test_export_model(c, tmpdir):\n    with pytest.raises(RuntimeError):\n        c.sql(\n            \"\"\"EXPORT MODEL not_available_model with (\n                format ='pickle',\n                location = '/tmp/model.pkl'\n            )\"\"\"\n        )\n\n    c.sql(\n        \"\"\"\n    
    CREATE MODEL IF NOT EXISTS my_model WITH (\n            model_class = 'GradientBoostingClassifier',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    # Happy flow\n    temporary_file = os.path.join(tmpdir, \"pickle_model.pkl\")\n    c.sql(\n        \"\"\"EXPORT MODEL my_model with (\n            format ='pickle',\n            location = '{}'\n        )\"\"\".format(\n            temporary_file\n        )\n    )\n\n    assert (\n        pickle.load(open(str(temporary_file), \"rb\")).estimator.__class__.__name__\n        == \"GradientBoostingClassifier\"\n    )\n\n    temporary_file = os.path.join(tmpdir, \"model.joblib\")\n    c.sql(\n        \"\"\"EXPORT MODEL my_model with (\n            format ='joblib',\n            location = '{}'\n        )\"\"\".format(\n            temporary_file\n        )\n    )\n\n    assert (\n        joblib.load(str(temporary_file)).estimator.__class__.__name__\n        == \"GradientBoostingClassifier\"\n    )\n\n    with pytest.raises(NotImplementedError):\n        temporary_dir = os.path.join(tmpdir, \"model.onnx\")\n        c.sql(\n            \"\"\"EXPORT MODEL my_model with (\n                format ='onnx',\n                location = '{}'\n            )\"\"\".format(\n                temporary_dir\n            )\n        )\n\n\ndef test_mlflow_export(c, tmpdir):\n    # Test only when mlflow was installed\n    mlflow = pytest.importorskip(\"mlflow\", reason=\"mlflow not installed\")\n\n    c.sql(\n        \"\"\"\n        CREATE MODEL IF NOT EXISTS my_model WITH (\n            model_class = 'GradientBoostingClassifier',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    temporary_dir = os.path.join(tmpdir, \"mlflow\")\n    c.sql(\n        \"\"\"EXPORT MODEL 
my_model with (\n            format ='mlflow',\n            location = '{}'\n        )\"\"\".format(\n            temporary_dir\n        )\n    )\n\n    # for sklearn compatible model\n    assert (\n        mlflow.sklearn.load_model(str(temporary_dir)).estimator.__class__.__name__\n        == \"GradientBoostingClassifier\"\n    )\n\n    # test for non sklearn compatible model\n    c.sql(\n        \"\"\"\n        CREATE MODEL IF NOT EXISTS non_sklearn_model WITH (\n            model_class = 'mock.MagicMock',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    temporary_dir = os.path.join(tmpdir, \"non_sklearn\")\n    with pytest.raises(NotImplementedError):\n        c.sql(\n            \"\"\"EXPORT MODEL non_sklearn_model with (\n                format ='mlflow',\n                location = '{}'\n            )\"\"\".format(\n                temporary_dir\n            )\n        )\n\n\n@pytest.mark.xfail(\n    sys.platform == \"darwin\", reason=\"Intermittent socket errors on macOS\", strict=False\n)\ndef test_mlflow_export_xgboost(c, client, tmpdir):\n    # Test only when mlflow & xgboost are installed\n    mlflow = pytest.importorskip(\"mlflow\", reason=\"mlflow not installed\")\n    xgboost = pytest.importorskip(\"xgboost\", reason=\"xgboost not installed\")\n\n    c.sql(\n        \"\"\"\n        CREATE MODEL IF NOT EXISTS my_model_xgboost WITH (\n            model_class = 'DaskXGBClassifier',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    temporary_dir = os.path.join(tmpdir, \"mlflow_xgboost\")\n    c.sql(\n        \"\"\"EXPORT MODEL my_model_xgboost with (\n            format = 'mlflow',\n            location = '{}'\n        )\"\"\".format(\n            temporary_dir\n        )\n    )\n\n    
assert (\n        mlflow.sklearn.load_model(str(temporary_dir)).__class__.__name__\n        == \"DaskXGBClassifier\"\n    )\n\n\ndef test_mlflow_export_lightgbm(c, tmpdir):\n    # Test only when mlflow & lightgbm are installed\n    mlflow = pytest.importorskip(\"mlflow\", reason=\"mlflow not installed\")\n    lightgbm = pytest.importorskip(\"lightgbm\", reason=\"lightgbm not installed\")\n\n    c.sql(\n        \"\"\"\n        CREATE MODEL IF NOT EXISTS my_model_lightgbm WITH (\n            model_class = 'LGBMClassifier',\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    temporary_dir = os.path.join(tmpdir, \"mlflow_lightgbm\")\n    c.sql(\n        \"\"\"EXPORT MODEL my_model_lightgbm with (\n            format = 'mlflow',\n            location = '{}'\n        )\"\"\".format(\n            temporary_dir\n        )\n    )\n\n    assert (\n        mlflow.sklearn.load_model(str(temporary_dir)).__class__.__name__\n        == \"LGBMClassifier\"\n    )\n\n\ndef test_ml_experiment(c, client):\n    with pytest.raises(\n        ValueError,\n        match=\"Parameters must include a 'model_class' \" \"or 'automl_class' parameter.\",\n    ):\n        c.sql(\n            \"\"\"\n        CREATE EXPERIMENT my_exp WITH (\n            experiment_class = 'GridSearchCV',\n            tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001],\n                               max_depth = ARRAY [3,4,5,10]),\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n        \"\"\"\n        )\n\n    with pytest.raises(\n        ValueError,\n        match=\"Parameters must include a 'experiment_class' \"\n        \"parameter for tuning GradientBoostingClassifier.\",\n    ):\n        c.sql(\n            \"\"\"\n        
CREATE EXPERIMENT my_exp WITH (\n            model_class = 'GradientBoostingClassifier',\n            tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001],\n                               max_depth = ARRAY [3,4,5,10]),\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n        \"\"\"\n        )\n\n    with pytest.raises(\n        ValueError,\n        match=\"Can not import model that.is.not.a.python.class. Make sure you spelled \"\n        \"it correctly and have installed all packages.\",\n    ):\n        c.sql(\n            \"\"\"\n            CREATE EXPERIMENT IF NOT EXISTS my_exp WITH (\n            model_class = 'that.is.not.a.python.class',\n            experiment_class = 'GridSearchCV',\n            tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001],\n                               max_depth = ARRAY [3,4,5,10]),\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n        \"\"\"\n        )\n\n    with pytest.raises(\n        ValueError,\n        match=\"Can not import tuner that.is.not.a.python.class. 
Make sure you spelled \"\n        \"it correctly and have installed all packages.\",\n    ):\n        c.sql(\n            \"\"\"\n            CREATE EXPERIMENT IF NOT EXISTS my_exp WITH (\n            model_class =  'GradientBoostingClassifier',\n            experiment_class = 'that.is.not.a.python.class',\n            tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001],\n                               max_depth = ARRAY [3,4,5,10]),\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n        \"\"\"\n        )\n\n    with pytest.raises(\n        ValueError,\n        match=\"Can not import automl model that.is.not.a.python.class. \"\n        \"Make sure you spelled \"\n        \"it correctly and have installed all packages.\",\n    ):\n        c.sql(\n            \"\"\"\n            CREATE EXPERIMENT my_exp64 WITH (\n                automl_class = 'that.is.not.a.python.class',\n                automl_kwargs = (\n                    population_size = 2,\n                    generations = 2,\n                    cv = 2,\n                    n_jobs = -1,\n                    use_dask = True,\n                    max_eval_time_mins = 1\n                ),\n                target_column = 'target'\n            ) AS (\n                SELECT x, y, x*y > 0 AS target\n                FROM timeseries\n                LIMIT 100\n            )\n            \"\"\"\n        )\n\n    # happy flow\n    c.sql(\n        \"\"\"\n        CREATE EXPERIMENT my_exp WITH (\n        model_class = 'GradientBoostingClassifier',\n        experiment_class = 'GridSearchCV',\n        tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001],\n                           max_depth = ARRAY [3,4,5,10]),\n        experiment_kwargs = (n_jobs = -1),\n        target_column = 'target'\n    ) AS (\n            SELECT x, y, x*y 
> 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n        \"\"\"\n    )\n    assert \"my_exp\" in c.schema[c.schema_name].models, \"Best model was not registered\"\n    check_trained_model(c, \"my_exp\")\n\n    with pytest.raises(RuntimeError):\n        # my_exp already exists\n        c.sql(\n            \"\"\"\n            CREATE EXPERIMENT my_exp WITH (\n            model_class = 'GradientBoostingClassifier',\n            experiment_class = 'GridSearchCV',\n            tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001],\n                               max_depth = ARRAY [3,4,5,10]),\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n            \"\"\"\n        )\n\n    c.sql(\n        \"\"\"\n        CREATE EXPERIMENT IF NOT EXISTS my_exp WITH (\n            model_class = 'GradientBoostingClassifier',\n            experiment_class = 'GridSearchCV',\n            tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001],\n                               max_depth = ARRAY [3,4,5,10]),\n            experiment_kwargs = (n_jobs = -1),\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n        \"\"\"\n    )\n\n    c.sql(\n        \"\"\"\n        CREATE OR REPLACE EXPERIMENT my_exp WITH (\n            model_class = 'GradientBoostingClassifier',\n            experiment_class = 'GridSearchCV',\n            tune_parameters = (n_estimators = ARRAY [16, 32, 2],learning_rate = ARRAY [0.1,0.01,0.001],\n                               max_depth = ARRAY [3,4,5,10]),\n            experiment_kwargs = (n_jobs = -1),\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n    
    )\n        \"\"\"\n    )\n\n    with pytest.raises(\n        ValueError,\n        match=\"Unsupervised Algorithm cannot be tuned Automatically,\"\n        \"Consider providing 'target column'\",\n    ):\n        c.sql(\n            \"\"\"\n            CREATE EXPERIMENT my_exp1 WITH (\n                model_class = 'KMeans',\n                experiment_class = 'RandomizedSearchCV',\n                tune_parameters = (n_clusters = ARRAY [3,4,16],tol = ARRAY [0.1,0.01,0.001],\n                                   max_iter = ARRAY [3,4,5,10])\n            ) AS (\n                SELECT x, y\n                FROM timeseries\n                LIMIT 100\n            )\n            \"\"\"\n        )\n\n\n@pytest.mark.xfail(\n    reason=\"tpot is broken with sklearn==1.4.0\", condition=SKLEARN_EQ_140\n)\ndef test_experiment_automl_classifier(c, client):\n    tpot = pytest.importorskip(\"tpot\", reason=\"tpot not installed\")\n\n    c.sql(\n        \"\"\"\n        CREATE EXPERIMENT my_automl_exp1 WITH (\n            automl_class = 'tpot.TPOTClassifier',\n            automl_kwargs = (population_size=2, generations=2, cv=2, n_jobs=-1),\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n        \"\"\"\n    )\n\n    assert (\n        \"my_automl_exp1\" in c.schema[c.schema_name].models\n    ), \"Best model was not registered\"\n\n    check_trained_model(c, \"my_automl_exp1\")\n\n\n@pytest.mark.xfail(\n    reason=\"tpot is broken with sklearn==1.4.0\", condition=SKLEARN_EQ_140\n)\ndef test_experiment_automl_regressor(c, client):\n    tpot = pytest.importorskip(\"tpot\", reason=\"tpot not installed\")\n\n    # test regressor\n    c.sql(\n        \"\"\"\n        CREATE EXPERIMENT my_automl_exp2 WITH (\n            automl_class = 'tpot.TPOTRegressor',\n            automl_kwargs = (population_size=2,\n            generations=2,\n            cv=2,\n            n_jobs=-1,\n   
         max_eval_time_mins=1),\n\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y  AS target\n            FROM timeseries\n            LIMIT 100\n        )\n        \"\"\"\n    )\n\n    assert (\n        \"my_automl_exp2\" in c.schema[c.schema_name].models\n    ), \"Best model was not registered\"\n\n    check_trained_model(c, \"my_automl_exp2\")\n\n\ndef test_predict_with_nullable_types(c):\n    df = pd.DataFrame(\n        {\n            \"rough_day_of_year\": [0, 1, 2, 3],\n            \"prev_day_inches_rained\": [0.0, 1.0, 2.0, 3.0],\n            \"rained\": [False, False, False, True],\n        }\n    )\n    c.create_table(\"train_set\", df)\n\n    model_class = \"'LogisticRegression'\"\n\n    c.sql(\n        f\"\"\"\n        CREATE OR REPLACE MODEL model WITH (\n            model_class = {model_class},\n            wrap_predict = True,\n            wrap_fit = False,\n            target_column = 'rained'\n        ) AS (\n            SELECT *\n            FROM train_set\n        )\n        \"\"\"\n    )\n\n    expected = c.sql(\n        \"\"\"\n        SELECT * FROM PREDICT(\n            MODEL model,\n            SELECT * FROM train_set\n        )\n        \"\"\"\n    )\n\n    df = pd.DataFrame(\n        {\n            \"rough_day_of_year\": pd.Series([0, 1, 2, 3], dtype=\"Int32\"),\n            \"prev_day_inches_rained\": pd.Series([0.0, 1.0, 2.0, 3.0], dtype=\"Float32\"),\n            \"rained\": pd.Series([False, False, False, True]),\n        }\n    )\n    c.create_table(\"train_set\", df)\n\n    c.sql(\n        f\"\"\"\n        CREATE OR REPLACE MODEL model WITH (\n            model_class = {model_class},\n            wrap_predict = True,\n            wrap_fit = False,\n            target_column = 'rained'\n        ) AS (\n            SELECT *\n            FROM train_set\n        )\n        \"\"\"\n    )\n\n    result = c.sql(\n        \"\"\"\n        SELECT * FROM PREDICT(\n            MODEL model,\n            SELECT * 
FROM train_set\n        )\n        \"\"\"\n    )\n\n    assert_eq(\n        expected,\n        result,\n        check_dtype=False,\n    )\n\n\ndef test_predict_with_limit_offset(c):\n    c.sql(\n        \"\"\"\n        CREATE MODEL my_model WITH (\n            model_class = 'GradientBoostingClassifier',\n            wrap_predict = True,\n            target_column = 'target'\n        ) AS (\n            SELECT x, y, x*y > 0 AS target\n            FROM timeseries\n            LIMIT 100\n        )\n    \"\"\"\n    )\n\n    res = c.sql(\n        \"\"\"\n        SELECT * FROM PREDICT (\n            MODEL my_model,\n            SELECT x, y FROM timeseries LIMIT 100 OFFSET 100\n        )\n    \"\"\"\n    )\n\n    res.compute()\n"
  },
  {
    "path": "tests/integration/test_over.py",
    "content": "import pandas as pd\nimport pytest\n\nfrom tests.utils import assert_eq, skipif_dask_expr_enabled\n\n\ndef test_over_with_sorting(c, user_table_1):\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        user_id,\n        b,\n        ROW_NUMBER() OVER (ORDER BY user_id, b) AS \"R\"\n    FROM user_table_1\n    \"\"\"\n    )\n    expected_df = user_table_1.sort_values([\"user_id\", \"b\"])\n    expected_df[\"R\"] = [1, 2, 3, 4]\n\n    assert_eq(return_df, expected_df, check_dtype=False, check_index=False)\n\n\ndef test_over_with_partitioning(c, user_table_2):\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        user_id,\n        c,\n        ROW_NUMBER() OVER (PARTITION BY c) AS \"R\"\n    FROM user_table_2\n    ORDER BY user_id, c\n    \"\"\"\n    )\n    expected_df = user_table_2.sort_values([\"user_id\", \"c\"])\n    expected_df[\"R\"] = [1, 1, 1, 1]\n\n    assert_eq(return_df, expected_df, check_dtype=False, check_index=False)\n\n\ndef test_over_with_grouping_and_sort(c, user_table_1):\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        user_id,\n        b,\n        ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY b) AS \"R\"\n    FROM user_table_1\n    \"\"\"\n    )\n    expected_df = user_table_1.sort_values([\"user_id\", \"b\"])\n    expected_df[\"R\"] = [1, 1, 2, 1]\n\n    assert_eq(return_df, expected_df, check_dtype=False, check_index=False)\n\n\ndef test_over_with_different(c, user_table_1):\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        user_id,\n        b,\n        ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY b) AS \"R1\",\n        ROW_NUMBER() OVER (ORDER BY user_id, b) AS \"R2\"\n    FROM user_table_1\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"user_id\": user_table_1.user_id,\n            \"b\": user_table_1.b,\n            \"R1\": [2, 1, 1, 1],\n            \"R2\": [3, 1, 2, 4],\n        }\n    )\n\n    assert_eq(return_df, expected_df, check_dtype=False, 
check_index=False)\n\n\n# TODO: investigate source of window count deadlocks\n@skipif_dask_expr_enabled(\"Deadlocks with query planning enabled\")\ndef test_over_calls(c, user_table_1):\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        user_id,\n        b,\n        ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY b) AS \"O1\",\n        FIRST_VALUE(user_id*10 - b) OVER (PARTITION BY user_id ORDER BY b) AS \"O2\",\n        -- SINGLE_VALUE(user_id*10 - b) OVER (PARTITION BY user_id ORDER BY b) AS \"O3\",\n        LAST_VALUE(user_id*10 - b) OVER (PARTITION BY user_id ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS \"O4\",\n        SUM(user_id) OVER (PARTITION BY user_id ORDER BY b) AS \"O5\",\n        AVG(user_id) OVER (PARTITION BY user_id ORDER BY b) AS \"O6\",\n        COUNT(*) OVER (PARTITION BY user_id ORDER BY b) AS \"O7\",\n        COUNT(b) OVER (PARTITION BY user_id ORDER BY b) AS \"O7b\",\n        MAX(b) OVER (PARTITION BY user_id ORDER BY b) AS \"O8\",\n        MIN(b) OVER (PARTITION BY user_id ORDER BY b) AS \"O9\"\n    FROM user_table_1\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"user_id\": user_table_1.user_id,\n            \"b\": user_table_1.b,\n            \"O1\": [2, 1, 1, 1],\n            \"O2\": [19, 7, 19, 27],\n            # \"O3\": [19, 7, 19, 27], https://github.com/dask-contrib/dask-sql/issues/651\n            \"O4\": [17, 7, 17, 27],\n            \"O5\": [4, 1, 2, 3],\n            \"O6\": [2, 1, 2, 3],\n            \"O7\": [2, 1, 1, 1],\n            \"O7b\": [2, 1, 1, 1],\n            \"O8\": [3, 3, 1, 3],\n            \"O9\": [1, 3, 1, 3],\n        }\n    )\n\n    assert_eq(return_df, expected_df, check_dtype=False, check_index=False)\n\n\n@pytest.mark.xfail(\n    reason=\"Need to add single_value window function, see https://github.com/dask-contrib/dask-sql/issues/651\"\n)\ndef test_over_single_value(c, user_table_1):\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n     
   user_id,\n        b,\n        SINGLE_VALUE(user_id*10 - b) OVER (PARTITION BY user_id ORDER BY b) AS \"O3\"\n    FROM user_table_1\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"user_id\": user_table_1.user_id,\n            \"b\": user_table_1.b,\n            \"O3\": [19, 7, 19, 27],\n        }\n    )\n\n    assert_eq(return_df, expected_df, check_dtype=False, check_index=False)\n\n\n# TODO: investigate source of window count deadlocks\n@skipif_dask_expr_enabled(\"Deadlocks with query planning enabled\")\ndef test_over_with_windows(c):\n    tmp_df = pd.DataFrame({\"a\": range(5)})\n    c.create_table(\"tmp\", tmp_df)\n\n    return_df = c.sql(\n        \"\"\"\n    SELECT\n        a,\n        SUM(a) OVER (ORDER BY a ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS \"O1\",\n        SUM(a) OVER (ORDER BY a ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) AS \"O2\",\n        SUM(a) OVER (ORDER BY a ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) AS \"O3\",\n        SUM(a) OVER (ORDER BY a ROWS BETWEEN CURRENT ROW AND 3 FOLLOWING) AS \"O4\",\n        SUM(a) OVER (ORDER BY a ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS \"O5\",\n        SUM(a) OVER (ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS \"O6\",\n        SUM(a) OVER (ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING) AS \"O7\",\n        SUM(a) OVER (ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS \"O8\",\n        SUM(a) OVER (ORDER BY a ROWS BETWEEN 3 FOLLOWING AND 3 FOLLOWING) AS \"O9\",\n        COUNT(a) OVER (ORDER BY a ROWS BETWEEN 3 FOLLOWING AND 3 FOLLOWING) AS \"O9a\",\n        SUM(a) OVER (ORDER BY a ROWS BETWEEN 3 PRECEDING AND 1 PRECEDING) AS \"O10\"\n    FROM tmp\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"a\": return_df.a,\n            \"O1\": [0, 1, 3, 6, 9],\n            \"O2\": [6, 10, 10, 10, 9],\n            \"O3\": [10, 10, 10, 10, 9],\n            \"O4\": [6, 10, 9, 7, 4],\n      
      \"O5\": [10, 10, 9, 7, 4],\n            \"O6\": [0, 1, 3, 6, 10],\n            \"O7\": [6, 10, 10, 10, 10],\n            \"O8\": [10, 10, 10, 10, 10],\n            \"O9\": [3, 4, None, None, None],\n            \"O9a\": [1, 1, 0, 0, 0],\n            \"O10\": [None, 0, 1, 3, 6],\n        }\n    )\n\n    assert_eq(return_df, expected_df, check_dtype=False, check_index=False)\n"
  },
  {
    "path": "tests/integration/test_postgres.py",
    "content": "import sys\n\nimport pytest\n\npytestmark = pytest.mark.xfail(\n    condition=sys.platform in (\"win32\", \"darwin\"),\n    reason=\"hive testing not supported on Windows/macOS\",\n)\ndocker = pytest.importorskip(\"docker\")\nsqlalchemy = pytest.importorskip(\"sqlalchemy\")\n\n\n@pytest.fixture(scope=\"session\")\ndef engine():\n    client = docker.from_env()\n\n    network = client.networks.create(\"dask-sql\", driver=\"bridge\")\n    postgres = client.containers.run(\n        \"postgres:latest\",\n        detach=True,\n        remove=True,\n        network=\"dask-sql\",\n        environment={\"POSTGRES_HOST_AUTH_METHOD\": \"trust\"},\n    )\n\n    try:\n        # Wait for it to start\n        start_counter = 2\n        postgres.exec_run([\"bash\"])\n        for l in postgres.logs(stream=True):\n            if b\"database system is ready to accept connections\" in l:\n                start_counter -= 1\n\n            if start_counter == 0:\n                break\n\n        # get the address and create the connection\n        postgres.reload()\n        address = postgres.attrs[\"NetworkSettings\"][\"Networks\"][\"dask-sql\"][\"IPAddress\"]\n        port = 5432\n\n        engine = sqlalchemy.create_engine(\n            f\"postgresql+psycopg2://postgres@{address}:{port}/postgres\"\n        )\n        yield engine\n    except Exception:\n        postgres.kill()\n        network.remove()\n\n        raise\n\n    postgres.kill()\n    network.remove()\n\n\n@pytest.mark.xfail(reason=\"WIP DataFusion\")\ndef test_select(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT * FROM df1\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            df1.user_id + 5,\n            2 * df1.a + df1.b / df1.user_id - df1.b,\n            df1.a IS NULL,\n            df1.a IS NOT NULL,\n            df1.b_bool IS TRUE,\n            df1.b_bool IS NOT TRUE,\n            df1.b_bool IS 
FALSE,\n            df1.b_bool IS NOT FALSE,\n            df1.b_bool IS UNKNOWN,\n            df1.b_bool IS NOT UNKNOWN,\n            ABS(df1.a),\n            ACOS(df1.a),\n            ASIN(df1.a),\n            ATAN(df1.a),\n            ATAN2(df1.a, df1.b),\n            CBRT(df1.a),\n            CEIL(df1.a),\n            COS(df1.a),\n            COT(df1.a),\n            DEGREES(df1.a),\n            EXP(df1.a),\n            FLOOR(df1.a),\n            LOG10(df1.a),\n            LN(df1.a),\n            POWER(df1.a, 3),\n            POWER(df1.a, -3),\n            POWER(df1.a, 1.1),\n            RADIANS(df1.a),\n            ROUND(df1.a),\n            SIGN(df1.a),\n            SIN(df1.a),\n            TAN(df1.a)\n        FROM df1\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT df2.user_id, df2.d FROM df2\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT 1 AS I, -5.34344 AS F, 'öäll' AS S\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT CASE WHEN user_id <> 3 THEN 4 ELSE 2 END FROM df2\n    \"\"\"\n    )\n\n\ndef test_join(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            df1.user_id, df1.a, df1.b,\n            df2.user_id AS user_id_2, df2.c, df2.d\n        FROM df1\n        JOIN df2 ON df1.user_id = df2.user_id\n    \"\"\",\n        [\"user_id\", \"a\", \"b\", \"user_id_2\", \"c\", \"d\"],\n    )\n\n\ndef test_sort(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            user_id, b\n        FROM df1\n        ORDER BY b NULLS FIRST, user_id DESC NULLS FIRST\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            c, d\n        FROM df2\n        ORDER BY c NULLS FIRST, d NULLS FIRST, user_id NULLS FIRST\n    \"\"\"\n    )\n\n\ndef test_limit(assert_query_gives_same_result):\n    
assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            c, d\n        FROM df2\n        ORDER BY c NULLS FIRST, d NULLS FIRST, user_id NULLS FIRST\n        LIMIT 10 OFFSET 20\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            c, d\n        FROM df2\n        ORDER BY c NULLS FIRST, d NULLS FIRST, user_id NULLS FIRST\n        LIMIT 200\n    \"\"\"\n    )\n\n\n@pytest.mark.xfail(reason=\"WIP DataFusion\")\ndef test_groupby(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            d, SUM(1.0 * c), AVG(1.0 * user_id)\n        FROM df2\n        WHERE d IS NOT NULL -- dask behaves differently on NaNs in groupbys\n        GROUP BY d\n        ORDER BY SUM(c)\n        LIMIT 10\n    \"\"\"\n    )\n\n\ndef test_filter(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            a\n        FROM df1\n        WHERE\n            user_id = 3 AND a > 0.5\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            d\n        FROM df2\n        WHERE\n            d NOT LIKE '%%c'\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            d\n        FROM df2\n        WHERE\n            (d NOT LIKE '%%c') IS NULL\n    \"\"\"\n    )\n\n\ndef test_string_operations(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            s,\n            s || 'hello' || s,\n            s SIMILAR TO '%%(b|d)%%',\n            s SIMILAR TO '%%(B|c)%%',\n            s SIMILAR TO '%%[a-zA-Z]%%',\n            s SIMILAR TO '.*',\n            s NOT SIMILAR TO '.*',\n            s LIKE '%%(b|d)%%',\n            s LIKE '%%(B|c)%%',\n            s LIKE '%%[a-zA-Z]%%',\n            s LIKE '.*',\n            S NOT LIKE '.*',\n            s ILIKE '%%(b|d)%%',\n            s ILIKE '%%(B|c)%%',\n     
       s NOT ILIKE '%%(b|d)%%',\n            s NOT ILIKE '%%(B|c)%%',\n            CHAR_LENGTH(s),\n            UPPER(s),\n            LOWER(s),\n            TRIM('a' FROM s),\n            TRIM(BOTH 'a' FROM s),\n            TRIM(LEADING 'a' FROM s),\n            TRIM(TRAILING 'a' FROM s),\n            SUBSTRING(s FROM -1),\n            SUBSTRING(s FROM 10),\n            SUBSTRING(s FROM 2),\n            SUBSTRING(s FROM 2 FOR 2),\n            SUBSTR(s,2,2) as s2,\n            INITCAP(s),\n            INITCAP(UPPER(s)),\n            INITCAP(LOWER(s))\n        FROM df3\n    \"\"\"\n    )\n\n\n@pytest.mark.xfail(reason=\"POSITION syntax not supported by parser\")\ndef test_string_position(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            POSITION('a' IN s),\n            POSITION('ZL' IN s)\n        FROM df3\n    \"\"\"\n    )\n\n\n@pytest.mark.xfail(reason=\"OVERLAY syntax not supported by parser\")\ndef test_string_overlay(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            OVERLAY(s PLACING 'XXX' FROM 2),\n            OVERLAY(s PLACING 'XXX' FROM 2 FOR 4),\n            OVERLAY(s PLACING 'XXX' FROM 2 FOR 1)\n        FROM df3\n    \"\"\"\n    )\n\n\n@pytest.mark.xfail(reason=\"WIP DataFusion\")\ndef test_statistical_functions(assert_query_gives_same_result):\n\n    # test regr_count\n    assert_query_gives_same_result(\n        \"\"\"\n        select user_id, REGR_COUNT(a,b) FROM df1 GROUP BY user_id\n        \"\"\",\n        [\"user_id\"],\n        check_names=False,\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        select user_id, REGR_SXX(a, 1.0 * b) FROM df1 GROUP BY user_id\n        \"\"\",\n        [\"user_id\"],\n        check_names=False,\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        select user_id, REGR_SYY(a, 1.0 * b) FROM df1 GROUP BY user_id\n        \"\"\",\n        [\"user_id\"],\n  
      check_names=False,\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        select user_id, COVAR_POP(a, b) FROM df1 GROUP BY user_id\n        \"\"\",\n        [\"user_id\"],\n        check_names=False,\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        select user_id,COVAR_SAMP(a,b) FROM df1 GROUP BY user_id\n        \"\"\",\n        [\"user_id\"],\n        check_names=False,\n    )\n"
  },
  {
    "path": "tests/integration/test_rex.py",
    "content": "from datetime import datetime\n\nimport dask.dataframe as dd\nimport numpy as np\nimport pandas as pd\nimport pytest\n\nfrom tests.utils import assert_eq\n\n\ndef test_year(c, datetime_table):\n    result_df = c.sql(\n        \"\"\"\n    SELECT year(timezone) from datetime_table\n    \"\"\"\n    )\n    assert result_df.shape[0].compute() == datetime_table.shape[0]\n    assert result_df.compute().iloc[0][0] == 2014\n\n\ndef test_case(c, df):\n    result_df = c.sql(\n        \"\"\"\n    SELECT\n        (CASE WHEN a = 3 THEN 1 END) AS \"S1\",\n        (CASE WHEN a > 0 THEN a ELSE 1 END) AS \"S2\",\n        (CASE WHEN a = 4 THEN 3 ELSE a + 1 END) AS \"S3\",\n        (CASE WHEN a = 3 THEN 1 WHEN a > 0 THEN 2 ELSE a END) AS \"S4\",\n        CASE\n            WHEN (a >= 1 AND a < 2) OR (a > 2) THEN CAST('in-between' AS VARCHAR) ELSE CAST('out-of-range' AS VARCHAR)\n        END AS \"S5\",\n        CASE\n            WHEN (a < 2) OR (3 < a AND a < 4) THEN 42 ELSE 47\n        END AS \"S6\",\n        CASE WHEN (1 < a AND a <= 4) THEN 1 ELSE 0 END AS \"S7\",\n        CASE a WHEN 2 THEN 5 ELSE a + 1 END AS \"S8\"\n    FROM df\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(index=df.index)\n    expected_df[\"S1\"] = df.a.apply(lambda a: 1 if a == 3 else np.NaN)\n    expected_df[\"S2\"] = df.a.apply(lambda a: a if a > 0 else 1)\n    expected_df[\"S3\"] = df.a.apply(lambda a: 3 if a == 4 else a + 1).astype(\"Int64\")\n    expected_df[\"S4\"] = df.a.apply(lambda a: 1 if a == 3 else 2 if a > 0 else a).astype(\n        \"Int64\"\n    )\n    expected_df[\"S5\"] = df.a.apply(\n        lambda a: \"in-between\" if ((1 <= a < 2) or (a > 2)) else \"out-of-range\"\n    )\n    expected_df[\"S6\"] = df.a.apply(lambda a: 42 if ((a < 2) or (3 < a < 4)) else 47)\n    expected_df[\"S7\"] = df.a.apply(lambda a: 1 if (1 < a <= 4) else 0)\n    expected_df[\"S8\"] = df.a.apply(lambda a: 5 if a == 2 else a + 1).astype(\"Int64\")\n\n    assert_eq(result_df, expected_df)\n\n\ndef 
test_intervals(c):\n    df = c.sql(\n        \"\"\"SELECT INTERVAL '3' DAY as \"IN\"\n        \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"IN\": [pd.to_timedelta(\"3d\")],\n        }\n    )\n    assert_eq(df, expected_df)\n\n    date1 = datetime(2021, 10, 3, 15, 53, 42, 47)\n    date2 = datetime(2021, 2, 28, 15, 53, 42, 47)\n    dates = dd.from_pandas(pd.DataFrame({\"d\": [date1, date2]}), npartitions=1)\n    c.create_table(\"dates\", dates)\n    df = c.sql(\n        \"\"\"SELECT d + INTERVAL '5 days' AS \"Plus_5_days\" FROM dates\n        \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"Plus_5_days\": [\n                datetime(2021, 10, 8, 15, 53, 42, 47),\n                datetime(2021, 3, 5, 15, 53, 42, 47),\n            ]\n        }\n    )\n    assert_eq(df, expected_df)\n\n\ndef test_literals(c):\n    df = c.sql(\n        \"\"\"SELECT 'a string äö' AS \"S\",\n                    4.4 AS \"F\",\n                    -4564347464 AS \"I\",\n                    TIME '08:08:00.091' AS \"T\",\n                    TIMESTAMP '2022-04-06 17:33:21' AS \"DT\",\n                    DATE '1991-06-02' AS \"D\",\n                    INTERVAL '1' DAY AS \"IN\"\n        \"\"\"\n    )\n\n    expected_df = pd.DataFrame(\n        {\n            \"S\": [\"a string äö\"],\n            \"F\": [4.4],\n            \"I\": [-4564347464],\n            \"T\": [pd.to_datetime(\"1970-01-01 08:08:00.091\")],\n            \"DT\": [pd.to_datetime(\"2022-04-06 17:33:21\")],\n            \"D\": [pd.to_datetime(\"1991-06-02 00:00\")],\n            \"IN\": [pd.to_timedelta(\"1d\")],\n        }\n    )\n    assert_eq(df, expected_df)\n\n\ndef test_date_interval_math(c):\n    df = c.sql(\n        \"\"\"SELECT\n                DATE '1998-08-18' - INTERVAL '4 days' AS \"before\",\n                DATE '1998-08-18' + INTERVAL '4 days' AS \"after\"\n        \"\"\"\n    )\n\n    expected_df = pd.DataFrame(\n        {\n            \"before\": 
[pd.to_datetime(\"1998-08-14 00:00\")],\n            \"after\": [pd.to_datetime(\"1998-08-22 00:00\")],\n        }\n    )\n    assert_eq(df, expected_df)\n\n\ndef test_literal_null(c):\n    df = c.sql(\n        \"\"\"\n    SELECT NULL AS \"N\", 1 + NULL AS \"I\"\n    \"\"\"\n    )\n\n    expected_df = pd.DataFrame({\"N\": [pd.NA], \"I\": [pd.NA]})\n    expected_df[\"I\"] = expected_df[\"I\"].astype(\"Int64\")\n    assert_eq(df, expected_df)\n\n\ndef test_random(c):\n    query_with_seed = \"\"\"\n            SELECT\n                RAND(0) AS \"0\",\n                RAND_INTEGER(0, 10) AS \"1\"\n            \"\"\"\n\n    result_df = c.sql(query_with_seed)\n\n    # assert that repeated queries give the same result\n    assert_eq(result_df, c.sql(query_with_seed))\n\n    # assert output\n    result_df = result_df.compute()\n\n    assert result_df[\"0\"].dtype == \"float64\"\n    assert result_df[\"1\"].dtype == \"Int64\"\n\n    assert 0 <= result_df[\"0\"][0] < 1\n    assert 0 <= result_df[\"1\"][0] < 10\n\n    query_wo_seed = \"\"\"\n        SELECT\n            RAND() AS \"0\",\n            RANDOM() AS \"1\",\n            RAND_INTEGER(30) AS \"2\"\n        \"\"\"\n    result_df = c.sql(query_wo_seed)\n    result_df = result_df.compute()\n    # assert output types\n\n    assert result_df[\"0\"].dtype == \"float64\"\n    assert result_df[\"1\"].dtype == \"float64\"\n    assert result_df[\"2\"].dtype == \"Int64\"\n\n    assert 0 <= result_df[\"0\"][0] < 1\n    assert 0 <= result_df[\"1\"][0] < 1\n    assert 0 <= result_df[\"2\"][0] < 30\n\n\n@pytest.mark.parametrize(\n    \"input_table\",\n    [\n        \"string_table\",\n        pytest.param(\"gpu_string_table\", marks=pytest.mark.gpu),\n    ],\n)\ndef test_not(c, input_table, request):\n    string_table = request.getfixturevalue(input_table)\n    df = c.sql(\n        f\"\"\"\n    SELECT\n        *\n    FROM {input_table}\n    WHERE NOT a LIKE '%normal%'\n    \"\"\"\n    )\n\n    expected_df = 
string_table[~string_table.a.str.contains(\"normal\")]\n    assert_eq(df, expected_df)\n\n\ndef test_operators(c, df):\n    result_df = c.sql(\n        \"\"\"\n    SELECT\n        a * b AS m,\n        -a AS u,\n        a / b AS q,\n        a + b AS s,\n        a - b AS d,\n        a = b AS e,\n        a > b AS g,\n        a >= b AS ge,\n        a < b AS l,\n        a <= b AS le,\n        a <> b AS n\n    FROM df\n    \"\"\"\n    )\n\n    expected_df = pd.DataFrame(index=df.index)\n    expected_df[\"m\"] = df[\"a\"] * df[\"b\"]\n    expected_df[\"u\"] = -df[\"a\"]\n    expected_df[\"q\"] = df[\"a\"] / df[\"b\"]\n    expected_df[\"s\"] = df[\"a\"] + df[\"b\"]\n    expected_df[\"d\"] = df[\"a\"] - df[\"b\"]\n    expected_df[\"e\"] = df[\"a\"] == df[\"b\"]\n    expected_df[\"g\"] = df[\"a\"] > df[\"b\"]\n    expected_df[\"ge\"] = df[\"a\"] >= df[\"b\"]\n    expected_df[\"l\"] = df[\"a\"] < df[\"b\"]\n    expected_df[\"le\"] = df[\"a\"] <= df[\"b\"]\n    expected_df[\"n\"] = df[\"a\"] != df[\"b\"]\n    assert_eq(result_df, expected_df)\n\n\n@pytest.mark.parametrize(\n    \"input_table,gpu\",\n    [\n        (\"string_table\", False),\n        pytest.param(\n            \"gpu_string_table\",\n            True,\n            marks=(\n                pytest.mark.gpu,\n                pytest.mark.xfail(\n                    reason=\"Failing due to cuDF bug https://github.com/rapidsai/cudf/issues/9434\"\n                ),\n            ),\n        ),\n    ],\n)\ndef test_like(c, input_table, gpu, request):\n    string_table = request.getfixturevalue(input_table)\n\n    df = c.sql(\n        f\"\"\"\n        SELECT * FROM {input_table}\n        WHERE a SIMILAR TO '%n[a-z]rmal st_i%'\n    \"\"\"\n    )\n    assert_eq(df, string_table.iloc[[0, 3]])\n\n    df = c.sql(\n        f\"\"\"\n        SELECT * FROM {input_table}\n        WHERE a NOT SIMILAR TO '%n[a-z]rmal st_i%'\n    \"\"\"\n    )\n    assert_eq(df, string_table.iloc[[1, 2]])\n\n    df = c.sql(\n        f\"\"\"\n        
SELECT * FROM {input_table}\n        WHERE a LIKE '%n[a-z]rmal st_i%'\n    \"\"\"\n    )\n    assert len(df) == 0\n\n    df = c.sql(\n        f\"\"\"\n        SELECT * FROM {input_table}\n        WHERE a NOT LIKE '%n[a-z]rmal st_i%'\n    \"\"\"\n    )\n    assert_eq(df, string_table)\n\n    df = c.sql(\n        f\"\"\"\n        SELECT * FROM {input_table}\n        WHERE a LIKE '%a Normal String%'\n    \"\"\"\n    )\n    assert len(df) == 0\n\n    df = c.sql(\n        f\"\"\"\n        SELECT * FROM {input_table}\n        WHERE a ILIKE '%a Normal String%'\n    \"\"\"\n    )\n    assert_eq(df, string_table.iloc[[0, 3]])\n\n    df = c.sql(\n        f\"\"\"\n        SELECT * FROM {input_table}\n        WHERE a NOT ILIKE '%a Normal String%'\n    \"\"\"\n    )\n    assert_eq(df, string_table.iloc[[1, 2]])\n    # TODO: uncomment when sqlparser adds parsing support for non-standard escape characters\n    # https://github.com/dask-contrib/dask-sql/issues/754\n    # df = c.sql(\n    #     f\"\"\"\n    #     SELECT * FROM {input_table}\n    #     WHERE a LIKE 'Ä%Ä_Ä%' ESCAPE 'Ä'\n    # \"\"\"\n    # )\n\n    # assert_eq(df, string_table.iloc[[1]])\n\n    df = c.sql(\n        f\"\"\"\n        SELECT * FROM {input_table}\n        WHERE a SIMILAR TO '^|()-*r[r]$' ESCAPE 'r'\n        \"\"\"\n    )\n\n    assert_eq(df, string_table.iloc[[2, 3]])\n\n    df = c.sql(\n        f\"\"\"\n        SELECT * FROM {input_table}\n        WHERE a LIKE '^|()-*r[r]$' ESCAPE 'r'\n    \"\"\"\n    )\n\n    assert_eq(df, string_table.iloc[[2]])\n\n    df = c.sql(\n        f\"\"\"\n        SELECT * FROM {input_table}\n        WHERE a LIKE '%_' ESCAPE 'r'\n    \"\"\"\n    )\n\n    assert_eq(df, string_table)\n\n    string_table2 = pd.DataFrame({\"b\": [\"a\", \"b\", None, pd.NA, float(\"nan\")]})\n    c.create_table(\"string_table2\", string_table2, gpu=gpu)\n    df = c.sql(\n        \"\"\"\n        SELECT * FROM string_table2\n        WHERE b LIKE 'b'\n    \"\"\"\n    )\n\n    assert_eq(df, 
string_table2.iloc[[1]])\n\n\ndef test_null(c):\n    df = c.sql(\n        \"\"\"\n        SELECT\n            c IS NOT NULL AS nn,\n            c IS NULL AS n\n        FROM user_table_nan\n    \"\"\"\n    )\n\n    expected_df = pd.DataFrame(index=[0, 1, 2])\n    expected_df[\"nn\"] = [True, False, True]\n    expected_df[\"nn\"] = expected_df[\"nn\"].astype(\"boolean\")\n    expected_df[\"n\"] = [False, True, False]\n    assert_eq(df, expected_df)\n\n    df = c.sql(\n        \"\"\"\n        SELECT\n            a IS NOT NULL AS nn,\n            a IS NULL AS n\n        FROM string_table\n    \"\"\"\n    )\n\n    expected_df = pd.DataFrame(index=[0, 1, 2, 3])\n    expected_df[\"nn\"] = [True, True, True, True]\n    expected_df[\"nn\"] = expected_df[\"nn\"].astype(\"boolean\")\n    expected_df[\"n\"] = [False, False, False, False]\n    assert_eq(df, expected_df)\n\n\n@pytest.mark.filterwarnings(\n    \"ignore:divide by zero:RuntimeWarning:dask_sql.physical.rex.core.call\"\n)\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_coalesce(c, gpu):\n    df = dd.from_pandas(\n        pd.DataFrame({\"a\": [1, 2, 3], \"b\": [np.nan] * 3}), npartitions=1\n    )\n    c.create_table(\"df\", df, gpu=gpu)\n\n    df = c.sql(\n        \"\"\"\n        SELECT\n            COALESCE(3, 5) as c1,\n            COALESCE(NULL, NULL) as c2,\n            COALESCE(NULL, 'hi') as c3,\n            COALESCE(NULL, NULL, 'bye', 5/0) as c4,\n            COALESCE(NULL, 3/2, NULL, 'fly') as c5,\n            COALESCE(NULL, MEAN(b), MEAN(a), 4/0) as c6\n        FROM df\n        \"\"\"\n    )\n\n    expected_df = pd.DataFrame(\n        {\n            \"c1\": [3],\n            \"c2\": [pd.NA],\n            \"c3\": [\"hi\"],\n            \"c4\": [\"bye\"],\n            \"c5\": [\"1.5\"],\n            \"c6\": [2.0],\n        }\n    )\n    expected_df[\"c2\"] = expected_df[\"c2\"].astype(\"Int8\")\n\n    assert_eq(df, expected_df, check_dtype=False)\n\n    df = 
c.sql(\n        \"\"\"\n        SELECT\n            COALESCE(a, b) as c1,\n            COALESCE(b, a) as c2,\n            COALESCE(a, a) as c3,\n            COALESCE(b, b) as c4\n        FROM df\n        \"\"\"\n    )\n\n    expected_df = pd.DataFrame(\n        {\n            \"c1\": [1, 2, 3],\n            \"c2\": [1, 2, 3],\n            \"c3\": [1, 2, 3],\n            \"c4\": [np.nan] * 3,\n        }\n    )\n\n    assert_eq(df, expected_df, check_dtype=False)\n    c.drop_table(\"df\")\n\n\ndef test_boolean_operations(c):\n    df = dd.from_pandas(pd.DataFrame({\"b\": [1, 0, -1]}), npartitions=1)\n    df[\"b\"] = df[\"b\"].apply(\n        lambda x: pd.NA if x < 0 else x > 0, meta=(\"b\", \"bool\")\n    )  # turn into a bool column\n    c.create_table(\"df\", df)\n\n    result_df = c.sql(\n        \"\"\"\n        SELECT\n            b IS TRUE AS t,\n            b IS FALSE AS f,\n            b IS NOT TRUE AS nt,\n            b IS NOT FALSE AS nf,\n            b IS UNKNOWN AS u,\n            b IS NOT UNKNOWN AS nu\n        FROM df\"\"\"\n    )\n\n    expected_df = pd.DataFrame(\n        {\n            \"t\": df.b.astype(\"boolean\").fillna(False),\n            \"f\": ~df.b.astype(\"boolean\").fillna(True),\n            \"nt\": ~df.b.astype(\"boolean\").fillna(False),\n            \"nf\": df.b.astype(\"boolean\").fillna(True),\n            \"u\": df.b.isna(),\n            \"nu\": ~df.b.isna().astype(\"boolean\"),\n        },\n    )\n    assert_eq(result_df, expected_df, check_dtype=False)\n\n\ndef test_math_operations(c, df):\n    result_df = c.sql(\n        \"\"\"\n        SELECT\n            ABS(b) AS \"abs\"\n            , ACOS(b) AS \"acos\"\n            , ASIN(b) AS \"asin\"\n            , ATAN(b) AS \"atan\"\n            , ATAN2(a, b) AS \"atan2\"\n            , CBRT(b) AS \"cbrt\"\n            , CEIL(b) AS \"ceil\"\n            , COS(b) AS \"cos\"\n            , COT(b) AS \"cot\"\n            , DEGREES(b) AS \"degrees\"\n            , EXP(b) AS \"exp\"\n         
   , FLOOR(b) AS \"floor\"\n            , LOG10(b) AS \"log10\"\n            , LN(b) AS \"ln\"\n            , MOD(b, 4) AS \"mod\"\n            , POWER(b, 2) AS \"power\"\n            , POWER(b, a) AS \"power2\"\n            , RADIANS(b) AS \"radians\"\n            , ROUND(b) AS \"round\"\n            , ROUND(b, 3) AS \"round2\"\n            , SIGN(b) AS \"sign\"\n            , SIN(b) AS \"sin\"\n            , TAN(b) AS \"tan\"\n            , TRUNCATE(b) AS \"truncate\"\n        FROM df\n    \"\"\"\n    )\n\n    expected_df = pd.DataFrame(index=df.index)\n    expected_df[\"abs\"] = df.b.abs()\n    expected_df[\"acos\"] = np.arccos(df.b)\n    expected_df[\"asin\"] = np.arcsin(df.b)\n    expected_df[\"atan\"] = np.arctan(df.b)\n    expected_df[\"atan2\"] = np.arctan2(df.a, df.b)\n    expected_df[\"cbrt\"] = np.cbrt(df.b)\n    expected_df[\"ceil\"] = np.ceil(df.b)\n    expected_df[\"cos\"] = np.cos(df.b)\n    expected_df[\"cot\"] = 1 / np.tan(df.b)\n    expected_df[\"degrees\"] = df.b / np.pi * 180\n    expected_df[\"exp\"] = np.exp(df.b)\n    expected_df[\"floor\"] = np.floor(df.b)\n    expected_df[\"log10\"] = np.log10(df.b)\n    expected_df[\"ln\"] = np.log(df.b)\n    expected_df[\"mod\"] = np.mod(df.b, 4)\n    expected_df[\"power\"] = np.power(df.b, 2)\n    expected_df[\"power2\"] = np.power(df.b, df.a)\n    expected_df[\"radians\"] = df.b / 180 * np.pi\n    expected_df[\"round\"] = np.round(df.b)\n    expected_df[\"round2\"] = np.round(df.b, 3)\n    expected_df[\"sign\"] = np.sign(df.b)\n    expected_df[\"sin\"] = np.sin(df.b)\n    expected_df[\"tan\"] = np.tan(df.b)\n    expected_df[\"truncate\"] = np.trunc(df.b)\n    assert_eq(result_df, expected_df)\n\n\ndef test_integer_div(c, df_simple):\n    df = c.sql(\n        \"\"\"\n        SELECT\n            1 / a AS a,\n            a / 2 AS b,\n            1.0 / a AS c\n        FROM df_simple\n    \"\"\"\n    )\n\n    expected_df = pd.DataFrame(\n        {\n            \"a\": (1 // df_simple.a).astype(\"Int64\"),\n   
         \"b\": (df_simple.a // 2).astype(\"Int64\"),\n            \"c\": 1 / df_simple.a,\n        }\n    )\n\n    assert_eq(df, expected_df)\n\n\n@pytest.mark.xfail(reason=\"Subquery expressions not yet enabled\")\ndef test_subqueries(c, user_table_1, user_table_2):\n    df = c.sql(\n        \"\"\"\n        SELECT *\n        FROM\n            user_table_2\n        WHERE\n            EXISTS(\n                SELECT *\n                FROM user_table_1\n                WHERE\n                    user_table_1.b = user_table_2.c\n            )\n    \"\"\"\n    )\n\n    assert_eq(df, user_table_2[user_table_2.c.isin(user_table_1.b)], check_index=False)\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_string_functions(c, gpu):\n    if gpu:\n        input_table = \"gpu_string_table\"\n    else:\n        input_table = \"string_table\"\n\n    df = c.sql(\n        f\"\"\"\n        SELECT\n            a || 'hello' || a AS a,\n            CONCAT(a, 'hello', a) as b,\n            CHAR_LENGTH(a) AS c,\n            UPPER(a) AS d,\n            LOWER(a) AS e,\n            -- POSITION('a' IN a FROM 4) AS f,\n            -- POSITION('ZL' IN a) AS g,\n            TRIM('a' FROM a) AS h,\n            TRIM(BOTH 'a' FROM a) AS i,\n            TRIM(LEADING 'a' FROM a) AS j,\n            TRIM(TRAILING 'a' FROM a) AS k,\n            -- OVERLAY(a PLACING 'XXX' FROM -1) AS l,\n            -- OVERLAY(a PLACING 'XXX' FROM 2 FOR 4) AS m,\n            -- OVERLAY(a PLACING 'XXX' FROM 2 FOR 1) AS n,\n            SUBSTRING(a FROM -1) AS o,\n            SUBSTRING(a FROM 10) AS p,\n            SUBSTRING(a FROM 2) AS q,\n            SUBSTRING(a FROM 2 FOR 2) AS r,\n            SUBSTR(a, 3, 6) AS s,\n            INITCAP(a) AS t,\n            INITCAP(UPPER(a)) AS u,\n            INITCAP(LOWER(a)) AS v,\n            REPLACE(a, 'r', 'l') as w,\n            REPLACE('Another String', 'th', 'b') as x\n        FROM\n            {input_table}\n        
\"\"\"\n    )\n\n    if gpu:\n        df = df.astype({\"c\": \"int64\"})  # , \"f\": \"int64\", \"g\": \"int64\"})\n\n    expected_df = pd.DataFrame(\n        {\n            \"a\": [\"a normal stringhelloa normal string\"],\n            \"b\": [\"a normal stringhelloa normal string\"],\n            \"c\": [15],\n            \"d\": [\"A NORMAL STRING\"],\n            \"e\": [\"a normal string\"],\n            # \"f\": [7], # position from syntax not supported\n            # \"g\": [0],\n            \"h\": [\" normal string\"],\n            \"i\": [\" normal string\"],\n            \"j\": [\" normal string\"],\n            \"k\": [\"a normal string\"],\n            # \"l\": [\"XXXormal string\"], # overlay from syntax not supported by parser\n            # \"m\": [\"aXXXmal string\"],\n            # \"n\": [\"aXXXnormal string\"],\n            \"o\": [\"a normal string\"],\n            \"p\": [\"string\"],\n            \"q\": [\" normal string\"],\n            \"r\": [\" n\"],\n            \"s\": [\"normal\"],\n            \"t\": [\"A Normal String\"],\n            \"u\": [\"A Normal String\"],\n            \"v\": [\"A Normal String\"],\n            \"w\": [\"a nolmal stling\"],\n            \"x\": [\"Anober String\"],\n        }\n    )\n\n    assert_eq(\n        df.head(1),\n        expected_df,\n    )\n\n\n@pytest.mark.xfail(reason=\"POSITION syntax not supported by parser\")\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_string_position(c, gpu):\n    if gpu:\n        input_table = \"gpu_string_table\"\n    else:\n        input_table = \"string_table\"\n\n    df = c.sql(\n        f\"\"\"\n        SELECT\n            POSITION('a' IN a FROM 4) AS f,\n            POSITION('ZL' IN a) AS g,\n        FROM\n            {input_table}\n        \"\"\"\n    )\n\n    if gpu:\n        df = df.astype({\"f\": \"int64\", \"g\": \"int64\"})\n\n    expected_df = pd.DataFrame(\n        {\n            \"f\": [7],\n            \"g\": 
[0],\n        }\n    )\n\n    assert_eq(\n        df.head(1),\n        expected_df,\n    )\n\n\n@pytest.mark.xfail(reason=\"OVERLAY syntax not supported by parser\")\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_string_overlay(c, gpu):\n    if gpu:\n        input_table = \"gpu_string_table\"\n    else:\n        input_table = \"string_table\"\n\n    df = c.sql(\n        f\"\"\"\n        SELECT\n            OVERLAY(a PLACING 'XXX' FROM -1) AS l,\n            OVERLAY(a PLACING 'XXX' FROM 2 FOR 4) AS m,\n            OVERLAY(a PLACING 'XXX' FROM 2 FOR 1) AS n\n        FROM\n            {input_table}\n        \"\"\"\n    )\n\n    expected_df = pd.DataFrame(\n        {\n            \"l\": [\"XXXormal string\"],\n            \"m\": [\"aXXXmal string\"],\n            \"n\": [\"aXXXnormal string\"],\n        }\n    )\n\n    assert_eq(\n        df.head(1),\n        expected_df,\n    )\n\n\ndef test_date_functions(c):\n    date = datetime(2021, 10, 3, 15, 53, 42, 47)\n\n    df = dd.from_pandas(pd.DataFrame({\"d\": [date]}), npartitions=1)\n    c.create_table(\"df\", df)\n\n    df = c.sql(\n        \"\"\"\n        SELECT\n            EXTRACT(CENTURY FROM d) AS \"century\",\n            EXTRACT(DAY FROM d) AS \"day\",\n            EXTRACT(DECADE FROM d) AS \"decade\",\n            EXTRACT(DOW FROM d) AS \"dow\",\n            EXTRACT(DOY FROM d) AS \"doy\",\n            EXTRACT(HOUR FROM d) AS \"hour\",\n            EXTRACT(MICROSECONDS FROM d) AS \"microsecond\",\n            EXTRACT(MILLENNIUM FROM d) AS \"millennium\",\n            EXTRACT(MILLISECONDS FROM d) AS \"millisecond\",\n            EXTRACT(MINUTE FROM d) AS \"minute\",\n            EXTRACT(MONTH FROM d) AS \"month\",\n            EXTRACT(QUARTER FROM d) AS \"quarter\",\n            EXTRACT(SECOND FROM d) AS \"second\",\n            EXTRACT(WEEK FROM d) AS 
\"week\",\n            EXTRACT(YEAR FROM d) AS \"year\",\n            EXTRACT(DATE FROM d) AS \"date\",\n\n            LAST_DAY(d) as \"last_day\",\n\n            TIMESTAMPADD(YEAR, 1, d) as \"plus_1_year\",\n            TIMESTAMPADD(MONTH, 1, d) as \"plus_1_month\",\n            TIMESTAMPADD(WEEK, 1, d) as \"plus_1_week\",\n            TIMESTAMPADD(DAY, 1, d) as \"plus_1_day\",\n            TIMESTAMPADD(HOUR, 1, d) as \"plus_1_hour\",\n            TIMESTAMPADD(MINUTE, 1, d) as \"plus_1_min\",\n            TIMESTAMPADD(SECOND, 1, d) as \"plus_1_sec\",\n            TIMESTAMPADD(MICROSECOND, 999*1000, d) as \"plus_999_millisec\",\n            TIMESTAMPADD(MICROSECOND, 999, d) as \"plus_999_microsec\",\n            TIMESTAMPADD(QUARTER, 1, d) as \"plus_1_qt\",\n\n            CEIL(d TO DAY) as ceil_to_day,\n            CEIL(d TO HOUR) as ceil_to_hour,\n            CEIL(d TO MINUTE) as ceil_to_minute,\n            CEIL(d TO SECOND) as ceil_to_seconds,\n            CEIL(d TO MILLISECOND) as ceil_to_millisec,\n\n            FLOOR(d TO DAY) as floor_to_day,\n            FLOOR(d TO HOUR) as floor_to_hour,\n            FLOOR(d TO MINUTE) as floor_to_minute,\n            FLOOR(d TO SECOND) as floor_to_seconds,\n            FLOOR(d TO MILLISECOND) as floor_to_millisec\n\n        FROM df\n    \"\"\"\n    )\n\n    expected_df = pd.DataFrame(\n        {\n            \"century\": [20],\n            \"day\": [3],\n            \"decade\": [202],\n            \"dow\": [0],\n            \"doy\": [276],\n            \"hour\": [15],\n            \"microsecond\": [47],\n            \"millennium\": [2],\n            \"millisecond\": [47000],\n            \"minute\": [53],\n            \"month\": [10],\n            \"quarter\": [4],\n            \"second\": [42],\n            \"week\": [39],\n            \"year\": [2021],\n            \"date\": [datetime(2021, 10, 3)],\n            \"last_day\": [datetime(2021, 10, 31, 15, 53, 42, 47)],\n            \"plus_1_year\": [datetime(2022, 10, 3, 
15, 53, 42, 47)],\n            \"plus_1_month\": [datetime(2021, 11, 3, 15, 53, 42, 47)],\n            \"plus_1_week\": [datetime(2021, 10, 10, 15, 53, 42, 47)],\n            \"plus_1_day\": [datetime(2021, 10, 4, 15, 53, 42, 47)],\n            \"plus_1_hour\": [datetime(2021, 10, 3, 16, 53, 42, 47)],\n            \"plus_1_min\": [datetime(2021, 10, 3, 15, 54, 42, 47)],\n            \"plus_1_sec\": [datetime(2021, 10, 3, 15, 53, 43, 47)],\n            \"plus_999_millisec\": [datetime(2021, 10, 3, 15, 53, 42, 1000 * 999 + 47)],\n            \"plus_999_microsec\": [datetime(2021, 10, 3, 15, 53, 42, 1046)],\n            \"plus_1_qt\": [datetime(2022, 1, 3, 15, 53, 42, 47)],\n            \"ceil_to_day\": [datetime(2021, 10, 4)],\n            \"ceil_to_hour\": [datetime(2021, 10, 3, 16)],\n            \"ceil_to_minute\": [datetime(2021, 10, 3, 15, 54)],\n            \"ceil_to_seconds\": [datetime(2021, 10, 3, 15, 53, 43)],\n            \"ceil_to_millisec\": [datetime(2021, 10, 3, 15, 53, 42, 1000)],\n            \"floor_to_day\": [datetime(2021, 10, 3)],\n            \"floor_to_hour\": [datetime(2021, 10, 3, 15)],\n            \"floor_to_minute\": [datetime(2021, 10, 3, 15, 53)],\n            \"floor_to_seconds\": [datetime(2021, 10, 3, 15, 53, 42)],\n            \"floor_to_millisec\": [datetime(2021, 10, 3, 15, 53, 42)],\n        }\n    )\n\n    assert_eq(df, expected_df, check_dtype=False)\n\n    # test exception handling\n    with pytest.raises(NotImplementedError):\n        df = c.sql(\n            \"\"\"\n            SELECT\n                FLOOR(d TO YEAR) as floor_to_year\n            FROM df\n            \"\"\"\n        )\n\n\ndef test_timestampdiff(c):\n    ts_literal1 = datetime(2002, 3, 7, 9, 10, 5, 123)\n    ts_literal2 = datetime(2001, 6, 5, 10, 11, 6, 234)\n    df = dd.from_pandas(\n        pd.DataFrame({\"ts_literal1\": [ts_literal1], \"ts_literal2\": [ts_literal2]}),\n        npartitions=1,\n    )\n    c.create_table(\"df\", df)\n\n    query = \"\"\"\n   
     SELECT timestampdiff(NANOSECOND, ts_literal1, ts_literal2) as res0,\n        timestampdiff(MICROSECOND, ts_literal1, ts_literal2) as res1,\n        timestampdiff(SECOND, ts_literal1, ts_literal2) as res2,\n        timestampdiff(MINUTE, ts_literal1, ts_literal2) as res3,\n        timestampdiff(HOUR, ts_literal1, ts_literal2) as res4,\n        timestampdiff(DAY, ts_literal1, ts_literal2) as res5,\n        timestampdiff(WEEK, ts_literal1, ts_literal2) as res6,\n        timestampdiff(MONTH, ts_literal1, ts_literal2) as res7,\n        timestampdiff(QUARTER, ts_literal1, ts_literal2) as res8,\n        timestampdiff(YEAR, ts_literal1, ts_literal2) as res9\n        FROM df\n    \"\"\"\n    df = c.sql(query)\n\n    expected_df = pd.DataFrame(\n        {\n            \"res0\": [-23756338999889000],\n            \"res1\": [-23756338999889],\n            \"res2\": [-23756338],\n            \"res3\": [-395938],\n            \"res4\": [-6598],\n            \"res5\": [-274],\n            \"res6\": [-39],\n            \"res7\": [-9],\n            \"res8\": [-3],\n            \"res9\": [0],\n        }\n    )\n\n    assert_eq(df, expected_df, check_dtype=False)\n\n    test = pd.DataFrame(\n        {\n            \"a\": [\n                datetime(2002, 6, 5, 2, 1, 5, 200),\n                datetime(2002, 9, 1),\n                datetime(1970, 12, 3),\n            ],\n            \"b\": [\n                datetime(2002, 6, 7, 1, 0, 2, 100),\n                datetime(2003, 6, 5),\n                datetime(2038, 6, 5),\n            ],\n        }\n    )\n    c.create_table(\"test\", test)\n\n    query = (\n        \"SELECT timestampdiff(NANOSECOND, a, b) as nanoseconds,\"\n        \"timestampdiff(MICROSECOND, a, b) as microseconds,\"\n        \"timestampdiff(SECOND, a, b) as seconds,\"\n        \"timestampdiff(MINUTE, a, b) as minutes,\"\n        \"timestampdiff(HOUR, a, b) as hours,\"\n        \"timestampdiff(DAY, a, b) as days,\"\n        \"timestampdiff(WEEK, a, b) as weeks,\"\n 
       \"timestampdiff(MONTH, a, b) as months,\"\n        \"timestampdiff(QUARTER, a, b) as quarters,\"\n        \"timestampdiff(YEAR, a, b) as years\"\n        \" FROM test\"\n    )\n    ddf = c.sql(query)\n\n    expected_df = pd.DataFrame(\n        {\n            \"nanoseconds\": [\n                169136999900000,\n                23932800000000000,\n                2130278400000000000,\n            ],\n            \"microseconds\": [169136999900, 23932800000000, 2130278400000000],\n            \"seconds\": [169136, 23932800, 2130278400],\n            \"minutes\": [2818, 398880, 35504640],\n            \"hours\": [46, 6648, 591744],\n            \"days\": [1, 277, 24656],\n            \"weeks\": [0, 39, 3522],\n            \"months\": [0, 9, 810],\n            \"quarters\": [0, 3, 270],\n            \"years\": [0, 0, 67],\n        }\n    )\n\n    assert_eq(ddf, expected_df, check_dtype=False)\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_totimestamp(c, gpu):\n    df = pd.DataFrame(\n        {\n            \"a\": np.array([1203073300, 1406073600, 2806073600]),\n        }\n    )\n    c.create_table(\"df\", df, gpu=gpu)\n\n    df = c.sql(\n        \"\"\"\n        SELECT to_timestamp(a) AS date FROM df\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"date\": [\n                datetime(2008, 2, 15, 11, 1, 40),\n                datetime(2014, 7, 23),\n                datetime(2058, 12, 2, 16, 53, 20),\n            ],\n        }\n    )\n    assert_eq(df, expected_df, check_dtype=False)\n\n    df = pd.DataFrame(\n        {\n            \"a\": np.array([\"1997-02-28 10:30:00\", \"1997-03-28 10:30:01\"]),\n        }\n    )\n    c.create_table(\"df\", df, gpu=gpu)\n\n    df = c.sql(\n        \"\"\"\n        SELECT to_timestamp(a) AS date FROM df\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"date\": [\n                datetime(1997, 2, 28, 10, 30, 0),\n       
         datetime(1997, 3, 28, 10, 30, 1),\n            ],\n        }\n    )\n    assert_eq(df, expected_df, check_dtype=False)\n\n    df = pd.DataFrame(\n        {\n            \"a\": np.array([\"02/28/1997\", \"03/28/1997\"]),\n        }\n    )\n    c.create_table(\"df\", df, gpu=gpu)\n\n    df = c.sql(\n        \"\"\"\n        SELECT to_timestamp(a, \"%m/%d/%Y\") AS date FROM df\n    \"\"\"\n    )\n    expected_df = pd.DataFrame(\n        {\n            \"date\": [\n                datetime(1997, 2, 28, 0, 0, 0),\n                datetime(1997, 3, 28, 0, 0, 0),\n            ],\n        }\n    )\n    # https://github.com/rapidsai/cudf/issues/12062\n    if not gpu:\n        assert_eq(df, expected_df, check_dtype=False)\n\n    int_input = 1203073300\n    df = c.sql(f\"SELECT to_timestamp({int_input}) as date\")\n    expected_df = pd.DataFrame(\n        {\n            \"date\": [\n                datetime(2008, 2, 15, 11, 1, 40),\n            ],\n        }\n    )\n    assert_eq(df, expected_df, check_dtype=False)\n\n    string_input = \"1997-02-28 10:30:00\"\n    df = c.sql(f\"SELECT to_timestamp('{string_input}') as date\")\n    expected_df = pd.DataFrame(\n        {\n            \"date\": [\n                datetime(1997, 2, 28, 10, 30, 0),\n            ],\n        }\n    )\n    assert_eq(df, expected_df, check_dtype=False)\n\n    string_input = \"02/28/1997\"\n    df = c.sql(f\"SELECT to_timestamp('{string_input}', '%m/%d/%Y') as date\")\n    expected_df = pd.DataFrame(\n        {\n            \"date\": [\n                datetime(1997, 2, 28, 0, 0, 0),\n            ],\n        }\n    )\n    assert_eq(df, expected_df, check_dtype=False)\n\n\n@pytest.mark.parametrize(\n    \"gpu\",\n    [\n        False,\n        pytest.param(\n            True,\n            marks=(pytest.mark.gpu,),\n        ),\n    ],\n)\ndef test_extract_date(c, gpu):\n    df = pd.DataFrame(\n        {\n            \"a\": [1, 2, 3],\n            \"b\": [4, 5, 6],\n        }\n    )\n    
df[\"t\"] = [datetime(2021, 1, 1), datetime(2022, 2, 2), datetime(2023, 3, 3)]\n    c.create_table(\"df\", df, gpu=gpu)\n\n    result = c.sql(\"SELECT EXTRACT(DATE FROM t) AS e FROM df\")\n    expected_df = pd.DataFrame(\n        {\"e\": [datetime(2021, 1, 1), datetime(2022, 2, 2), datetime(2023, 3, 3)]}\n    )\n    assert_eq(result, expected_df)\n\n    result = c.sql(\"SELECT * FROM df WHERE EXTRACT(DATE FROM t) > '2021-02-01'\")\n    expected_df = pd.DataFrame(\n        {\n            \"a\": [2, 3],\n            \"b\": [5, 6],\n            \"t\": [datetime(2022, 2, 2), datetime(2023, 3, 3)],\n        }\n    )\n    assert_eq(result, expected_df, check_index=False)\n\n    result = c.sql(\n        \"SELECT * FROM df WHERE EXTRACT(DATE FROM t) BETWEEN '2020-10-01' AND '2022-10-10'\"\n    )\n    expected_df = pd.DataFrame(\n        {\"a\": [1, 2], \"b\": [4, 5], \"t\": [datetime(2021, 1, 1), datetime(2022, 2, 2)]}\n    )\n    assert_eq(result, expected_df)\n\n    result = c.sql(\"SELECT TIMESTAMPADD(YEAR, 1, EXTRACT(DATE FROM t)) AS ta FROM df\")\n    expected_df = pd.DataFrame(\n        {\"ta\": [datetime(2022, 1, 1), datetime(2023, 2, 2), datetime(2024, 3, 3)]}\n    )\n    assert_eq(result, expected_df)\n\n    result = c.sql(\"SELECT EXTRACT(DATE FROM t) + INTERVAL '2 days' AS i FROM df\")\n    expected_df = pd.DataFrame(\n        {\"i\": [datetime(2021, 1, 3), datetime(2022, 2, 4), datetime(2023, 3, 5)]}\n    )\n    assert_eq(result, expected_df)\n\n\n@pytest.mark.parametrize(\n    \"gpu\",\n    [\n        False,\n        pytest.param(\n            True,\n            marks=(pytest.mark.gpu,),\n        ),\n    ],\n)\ndef test_scalar_timestamps(c, gpu):\n    df = pd.DataFrame({\"d\": [1203073300, 1503073700]})\n    c.create_table(\"df\", df, gpu=gpu)\n\n    expected_df = pd.DataFrame(\n        {\n            \"dt\": [datetime(2008, 2, 20, 11, 1, 40), datetime(2017, 8, 23, 16, 28, 20)],\n        }\n    )\n\n    df1 = c.sql(\"SELECT to_timestamp(d) + INTERVAL '5 days' 
AS dt FROM df\")\n    assert_eq(df1, expected_df)\n    df2 = c.sql(\"SELECT CAST(d AS TIMESTAMP) + INTERVAL '5 days' AS dt FROM df\")\n    assert_eq(df2, expected_df)\n\n    df1 = c.sql(\"SELECT TIMESTAMPADD(DAY, 5, to_timestamp(d)) AS dt FROM df\")\n    assert_eq(df1, expected_df)\n    df2 = c.sql(\"SELECT TIMESTAMPADD(DAY, 5, d) AS dt FROM df\")\n    assert_eq(df2, expected_df)\n    df3 = c.sql(\"SELECT TIMESTAMPADD(DAY, 5, CAST(d AS TIMESTAMP)) AS dt FROM df\")\n    assert_eq(df3, expected_df)\n\n    expected_df = pd.DataFrame({\"day\": [15, 18]})\n    df1 = c.sql(\"SELECT EXTRACT(DAY FROM to_timestamp(d)) AS day FROM df\")\n    assert_eq(df1, expected_df, check_dtype=False)\n    df2 = c.sql(\"SELECT EXTRACT(DAY FROM CAST(d AS TIMESTAMP)) AS day FROM df\")\n    assert_eq(df2, expected_df, check_dtype=False)\n\n    expected_df = pd.DataFrame(\n        {\n            \"ceil_to_day\": [datetime(2008, 2, 16), datetime(2017, 8, 19)],\n        }\n    )\n    df1 = c.sql(\"SELECT CEIL(to_timestamp(d) TO DAY) AS ceil_to_day FROM df\")\n    assert_eq(df1, expected_df, check_dtype=(not gpu))\n    df2 = c.sql(\"SELECT CEIL(CAST(d AS TIMESTAMP) TO DAY) AS ceil_to_day FROM df\")\n    assert_eq(df2, expected_df)\n\n    expected_df = pd.DataFrame(\n        {\n            \"floor_to_day\": [datetime(2008, 2, 15), datetime(2017, 8, 18)],\n        }\n    )\n    df1 = c.sql(\"SELECT FLOOR(to_timestamp(d) TO DAY) AS floor_to_day FROM df\")\n    assert_eq(df1, expected_df, check_dtype=(not gpu))\n    df2 = c.sql(\"SELECT FLOOR(CAST(d AS TIMESTAMP) TO DAY) AS floor_to_day FROM df\")\n    assert_eq(df2, expected_df)\n\n    df = pd.DataFrame({\"d1\": [1203073300], \"d2\": [1503073700]})\n    c.create_table(\"df\", df, gpu=gpu)\n    expected_df = pd.DataFrame({\"dt\": [3472]})\n    df1 = c.sql(\n        \"SELECT TIMESTAMPDIFF(DAY, to_timestamp(d1), to_timestamp(d2)) AS dt FROM df\"\n    )\n    # TODO: The GPU case returns an incorrect value here\n    if not gpu:\n        assert_eq(df1, 
expected_df)\n    df2 = c.sql(\"SELECT TIMESTAMPDIFF(DAY, d1, d2) AS dt FROM df\")\n    assert_eq(df2, expected_df, check_dtype=False)\n    df3 = c.sql(\n        \"SELECT TIMESTAMPDIFF(DAY, CAST(d1 AS TIMESTAMP), CAST(d2 AS TIMESTAMP)) AS dt FROM df\"\n    )\n    assert_eq(df3, expected_df)\n\n    scalar1 = 1203073300\n    scalar2 = 1503073700\n\n    expected_df = pd.DataFrame({\"dt\": [datetime(2008, 2, 20, 11, 1, 40)]})\n\n    df1 = c.sql(f\"SELECT to_timestamp({scalar1}) + INTERVAL '5 days' AS dt\")\n    assert_eq(df1, expected_df)\n    # TODO: Fix seconds/nanoseconds conversion\n    # df2 = c.sql(f\"SELECT CAST({scalar1} AS TIMESTAMP) + INTERVAL '5 days' AS dt\")\n    # assert_eq(df2, expected_df)\n\n    df1 = c.sql(f\"SELECT TIMESTAMPADD(DAY, 5, to_timestamp({scalar1})) AS dt\")\n    assert_eq(df1, expected_df)\n    df2 = c.sql(f\"SELECT TIMESTAMPADD(DAY, 5, {scalar1}) AS dt\")\n    assert_eq(df2, expected_df)\n    df3 = c.sql(f\"SELECT TIMESTAMPADD(DAY, 5, CAST({scalar1} AS TIMESTAMP)) AS dt\")\n    assert_eq(df3, expected_df)\n\n    expected_df = pd.DataFrame({\"day\": [15]})\n    df1 = c.sql(f\"SELECT EXTRACT(DAY FROM to_timestamp({scalar1})) AS day\")\n    assert_eq(df1, expected_df, check_dtype=False)\n    # TODO: Fix seconds/nanoseconds conversion\n    # df2 = c.sql(f\"SELECT EXTRACT(DAY FROM CAST({scalar1} AS TIMESTAMP)) AS day\")\n    # assert_eq(df2, expected_df, check_dtype=False)\n\n    expected_df = pd.DataFrame({\"ceil_to_day\": [datetime(2008, 2, 16)]})\n    df1 = c.sql(f\"SELECT CEIL(to_timestamp({scalar1}) TO DAY) AS ceil_to_day\")\n    assert_eq(df1, expected_df)\n    df2 = c.sql(f\"SELECT CEIL(CAST({scalar1} AS TIMESTAMP) TO DAY) AS ceil_to_day\")\n    assert_eq(df2, expected_df)\n\n    expected_df = pd.DataFrame({\"floor_to_day\": [datetime(2008, 2, 15)]})\n    df1 = c.sql(f\"SELECT FLOOR(to_timestamp({scalar1}) TO DAY) AS floor_to_day\")\n    assert_eq(df1, expected_df)\n    df2 = c.sql(f\"SELECT FLOOR(CAST({scalar1} AS TIMESTAMP) TO DAY) 
AS floor_to_day\")\n    assert_eq(df2, expected_df)\n\n    expected_df = pd.DataFrame({\"dt\": [3472]})\n    df1 = c.sql(\n        f\"SELECT TIMESTAMPDIFF(DAY, to_timestamp({scalar1}), to_timestamp({scalar2})) AS dt\"\n    )\n    assert_eq(df1, expected_df)\n    df2 = c.sql(f\"SELECT TIMESTAMPDIFF(DAY, {scalar1}, {scalar2}) AS dt\")\n    assert_eq(df2, expected_df, check_dtype=False)\n    df3 = c.sql(\n        f\"SELECT TIMESTAMPDIFF(DAY, CAST({scalar1} AS TIMESTAMP), CAST({scalar2} AS TIMESTAMP)) AS dt\"\n    )\n    assert_eq(df3, expected_df)\n\n\ndef test_datetime_coercion(c):\n    d_table = pd.DataFrame(\n        {\n            \"d_date\": [\n                datetime(2023, 7, 1),\n                datetime(2023, 7, 5),\n                datetime(2023, 7, 10),\n                datetime(2023, 7, 15),\n            ],\n            \"x\": [1, 2, 3, 4],\n        }\n    )\n    c.create_table(\"d_table\", d_table)\n\n    df = c.sql(\n        \"\"\"\n        SELECT * FROM d_table d1, d_table d2\n        WHERE d2.x < d1.x + (1 + 2)\n        AND d2.d_date > d1.d_date + (2 + 3)\n    \"\"\"\n    )\n    expected_df = c.sql(\n        \"\"\"\n        SELECT * FROM d_table d1, d_table d2\n        WHERE d2.x < d1.x + (1 + 2)\n        AND d2.d_date > d1.d_date + INTERVAL '5 days'\n    \"\"\"\n    )\n    assert_eq(df, expected_df)\n"
  },
  {
    "path": "tests/integration/test_sample.py",
    "content": "import numpy as np\nimport pytest\n\nfrom tests.utils import assert_eq\n\n\ndef get_system_sample(df, fraction, seed):\n    random_state = np.random.RandomState(seed)\n    random_choice = random_state.choice(\n        [True, False],\n        size=df.npartitions,\n        replace=True,\n        p=[fraction, 1 - fraction],\n    )\n\n    if random_choice.any():\n        df = df.partitions[random_choice]\n    else:\n        df = df.head(0, compute=False)\n\n    return df\n\n\n@pytest.mark.xfail(reason=\"WIP DataFusion\")\ndef test_sample(c, df):\n    ddf = c.sql(\"SELECT * FROM df\")\n\n    # fixed system samples\n    assert_eq(\n        c.sql(\"SELECT * FROM df TABLESAMPLE SYSTEM (20) REPEATABLE (10)\"),\n        get_system_sample(ddf, 0.20, 10),\n    )\n    assert_eq(\n        c.sql(\"SELECT * FROM df TABLESAMPLE SYSTEM (20) REPEATABLE (11)\"),\n        get_system_sample(ddf, 0.20, 11),\n    )\n    assert_eq(\n        c.sql(\"SELECT * FROM df TABLESAMPLE SYSTEM (50) REPEATABLE (10)\"),\n        get_system_sample(ddf, 0.50, 10),\n    )\n    assert_eq(\n        c.sql(\"SELECT * FROM df TABLESAMPLE SYSTEM (0.001) REPEATABLE (10)\"),\n        get_system_sample(ddf, 0.00001, 10),\n    )\n    assert_eq(\n        c.sql(\"SELECT * FROM df TABLESAMPLE SYSTEM (99.999) REPEATABLE (10)\"),\n        get_system_sample(ddf, 0.99999, 10),\n    )\n\n    # fixed bernoulli samples\n    assert_eq(\n        c.sql(\"SELECT * FROM df TABLESAMPLE BERNOULLI (50) REPEATABLE (10)\"),\n        ddf.sample(frac=0.50, replace=False, random_state=10),\n    )\n    assert_eq(\n        c.sql(\"SELECT * FROM df TABLESAMPLE BERNOULLI (70) REPEATABLE (10)\"),\n        ddf.sample(frac=0.70, replace=False, random_state=10),\n    )\n    assert_eq(\n        c.sql(\"SELECT * FROM df TABLESAMPLE BERNOULLI (0.001) REPEATABLE (10)\"),\n        ddf.sample(frac=0.00001, replace=False, random_state=10),\n    )\n    assert_eq(\n        c.sql(\"SELECT * FROM df TABLESAMPLE BERNOULLI (99.999) 
REPEATABLE (10)\"),\n        ddf.sample(frac=0.99999, replace=False, random_state=10),\n    )\n\n    # variable samples, can only check boundaries\n    return_df = c.sql(\"SELECT * FROM df TABLESAMPLE BERNOULLI (50)\")\n    assert len(return_df) >= 0 and len(return_df) <= len(df)\n\n    return_df = c.sql(\"SELECT * FROM df TABLESAMPLE SYSTEM (50)\")\n    assert len(return_df) >= 0 and len(return_df) <= len(df)\n"
  },
  {
    "path": "tests/integration/test_schema.py",
    "content": "import dask.dataframe as dd\nimport numpy as np\nimport pytest\n\nfrom dask_sql.utils import ParsingException\nfrom tests.utils import assert_eq\n\n\n@pytest.mark.xfail(reason=\"WIP DataFusion\")\ndef test_table_schema(c, df):\n    original_df = c.sql(\"SELECT * FROM df\")\n\n    assert_eq(original_df, c.sql(\"SELECT * FROM root.df\"))\n\n    c.sql(\"CREATE SCHEMA foo\")\n    assert_eq(original_df, c.sql(\"SELECT * FROM df\"))\n\n    c.sql('USE SCHEMA \"foo\"')\n    assert_eq(original_df, c.sql(\"SELECT * FROM root.df\"))\n\n    c.sql(\"CREATE TABLE bar AS TABLE root.df\")\n    assert_eq(original_df, c.sql(\"SELECT * FROM bar\"))\n\n    with pytest.raises(KeyError):\n        c.sql(\"CREATE TABLE other.bar AS TABLE df\")\n\n    c.sql('USE SCHEMA \"root\"')\n    assert_eq(original_df, c.sql(\"SELECT * FROM foo.bar\"))\n\n    with pytest.raises(ParsingException):\n        c.sql(\"SELECT * FROM bar\")\n\n    c.sql(\"DROP SCHEMA foo\")\n\n    with pytest.raises(ParsingException):\n        c.sql(\"SELECT * FROM foo.bar\")\n\n\n@pytest.mark.xfail(reason=\"WIP DataFusion\")\ndef test_function(c):\n    c.sql(\"CREATE SCHEMA other\")\n    c.sql(\"USE SCHEMA root\")\n\n    def f(x):\n        return x**2\n\n    c.register_function(f, \"f\", [(\"x\", np.float64)], np.float64, schema_name=\"other\")\n\n    with pytest.raises(ParsingException):\n        c.sql(\"SELECT F(a) AS a FROM df\")\n\n    c.sql(\"SELECT other.F(a) AS a FROM df\")\n\n    c.sql(\"USE SCHEMA other\")\n    c.sql(\"SELECT F(a) AS a FROM root.df\")\n\n    c.sql(\"USE SCHEMA root\")\n    fagg = dd.Aggregation(\"f\", lambda x: x.sum(), lambda x: x.sum())\n    c.register_aggregation(\n        fagg, \"fagg\", [(\"x\", np.float64)], np.float64, schema_name=\"other\"\n    )\n\n    with pytest.raises(ParsingException):\n        c.sql(\"SELECT FAGG(b) AS test FROM df\")\n\n    c.sql(\"SELECT other.FAGG(b) AS test FROM df\")\n\n    c.sql(\"USE SCHEMA other\")\n    c.sql(\"SELECT FAGG(b) AS test FROM 
root.df\")\n\n\ndef test_create_schema(c):\n    c.sql(\"CREATE SCHEMA new_schema\")\n    assert \"new_schema\" in c.schema\n\n    with pytest.raises(RuntimeError):\n        c.sql(\"CREATE SCHEMA new_schema\")\n\n    c.sql(\"CREATE OR REPLACE SCHEMA new_schema\")\n    c.sql(\"CREATE SCHEMA IF NOT EXISTS new_schema\")\n\n\ndef test_drop_schema(c):\n    with pytest.raises(RuntimeError):\n        c.sql(\"DROP SCHEMA new_schema\")\n\n    c.sql(\"DROP SCHEMA IF EXISTS new_schema\")\n\n    c.sql(\"CREATE SCHEMA new_schema\")\n    c.sql(\"DROP SCHEMA IF EXISTS new_schema\")\n\n    with pytest.raises(RuntimeError):\n        c.sql(\"USE SCHEMA new_schema\")\n\n    with pytest.raises(RuntimeError):\n        c.sql(\"DROP SCHEMA root\")\n\n    c.sql(\"CREATE SCHEMA example\")\n    c.sql(\"USE SCHEMA example\")\n    c.sql(\"DROP SCHEMA example\")\n    assert c.schema_name == c.DEFAULT_SCHEMA_NAME\n    assert \"example\" not in c.schema\n"
  },
  {
    "path": "tests/integration/test_select.py",
    "content": "import numpy as np\nimport pandas as pd\nimport pytest\nfrom dask.dataframe.optimize import optimize_dataframe_getitem\nfrom dask.utils_test import hlg_layer\n\nfrom dask_sql.utils import ParsingException\nfrom tests.utils import assert_eq, skipif_dask_expr_enabled\n\n\ndef test_select(c, df):\n    result_df = c.sql(\"SELECT * FROM df\")\n\n    assert_eq(result_df, df)\n\n\ndef test_select_alias(c, df):\n    result_df = c.sql(\"SELECT a as b, b as a FROM df\")\n\n    expected_df = pd.DataFrame(index=df.index)\n    expected_df[\"b\"] = df.a\n    expected_df[\"a\"] = df.b\n\n    assert_eq(result_df[[\"a\", \"b\"]], expected_df[[\"a\", \"b\"]])\n\n\ndef test_select_column(c, df):\n    result_df = c.sql(\"SELECT a FROM df\")\n\n    assert_eq(result_df, df[[\"a\"]])\n\n\ndef test_select_different_types(c):\n    expected_df = pd.DataFrame(\n        {\n            \"date\": pd.to_datetime(\n                [\"2022-01-21 17:34\", \"2022-01-21\", \"17:34\", pd.NaT],\n                format=\"mixed\",\n            ),\n            \"string\": [\"this is a test\", \"another test\", \"äölüć\", \"\"],\n            \"integer\": [1, 2, -4, 5],\n            \"float\": [-1.1, np.nan, pd.NA, np.sqrt(2)],\n        }\n    )\n    c.create_table(\"df\", expected_df)\n    result_df = c.sql(\n        \"\"\"\n    SELECT *\n    FROM df\n    \"\"\"\n    )\n\n    assert_eq(result_df, expected_df)\n\n\ndef test_select_expr(c, df):\n    result_df = c.sql(\"SELECT a + 1 AS a, b AS bla, a - 1 FROM df\")\n\n    expected_df = pd.DataFrame(\n        {\n            \"a\": df[\"a\"] + 1,\n            \"bla\": df[\"b\"],\n            \"df.a - Int64(1)\": df[\"a\"] - 1,\n        }\n    )\n    assert_eq(result_df, expected_df)\n\n\ndef test_select_of_select(c, df):\n    result_df = c.sql(\n        \"\"\"\n        SELECT 2*c AS e, d - 1 AS f\n        FROM\n        (\n            SELECT a - 1 AS c, 2*b  AS d\n            FROM df\n        ) AS \"inner\"\n        
\"\"\"\n    )\n    expected_df = pd.DataFrame({\"e\": 2 * (df[\"a\"] - 1), \"f\": 2 * df[\"b\"] - 1})\n    assert_eq(result_df, expected_df)\n\n\n@pytest.mark.xfail(\n    reason=\"Column casing doesn't work as expected with datafusion>21, \"\n    \"https://github.com/apache/arrow-datafusion/issues/5626\"\n)\ndef test_select_of_select_with_casing(c, df):\n    result_df = c.sql(\n        \"\"\"\n        SELECT \"AAA\", \"aaa\", \"aAa\"\n        FROM\n        (\n            SELECT a - 1 AS \"aAa\", 2*b AS \"aaa\", a + b AS \"AAA\"\n            FROM df\n        ) AS \"inner\"\n        \"\"\"\n    )\n\n    expected_df = pd.DataFrame(\n        {\"AAA\": df[\"a\"] + df[\"b\"], \"aaa\": 2 * df[\"b\"], \"aAa\": df[\"a\"] - 1}\n    )\n\n    assert_eq(result_df, expected_df)\n\n\ndef test_wrong_input(c):\n    with pytest.raises(ParsingException):\n        c.sql(\"\"\"SELECT x FROM df\"\"\")\n\n    with pytest.raises(ParsingException):\n        c.sql(\"\"\"SELECT x FROM df\"\"\")\n\n\ndef test_timezones(c, datetime_table):\n    result_df = c.sql(\n        \"\"\"\n        SELECT * FROM datetime_table\n        \"\"\"\n    )\n\n    assert_eq(result_df, datetime_table)\n\n\n@pytest.mark.parametrize(\n    \"input_table\",\n    [\n        \"long_table\",\n        pytest.param(\"gpu_long_table\", marks=pytest.mark.gpu),\n    ],\n)\n@pytest.mark.parametrize(\n    \"limit,offset\",\n    [(101, 0), (200, 0), (100, 0), (100, 99), (100, 100), (101, 101), (0, 101)],\n)\ndef test_limit(c, input_table, limit, offset, request):\n    long_table = request.getfixturevalue(input_table)\n\n    if not limit:\n        query = f\"SELECT * FROM long_table OFFSET {offset}\"\n    else:\n        query = f\"SELECT * FROM long_table LIMIT {limit} OFFSET {offset}\"\n\n    assert_eq(c.sql(query), long_table.iloc[offset : offset + limit if limit else None])\n\n\n@pytest.mark.parametrize(\n    \"input_table\",\n    [\n        \"datetime_table\",\n        pytest.param(\"gpu_datetime_table\", 
marks=pytest.mark.gpu),\n    ],\n)\ndef test_date_casting(c, input_table, request):\n    datetime_table = request.getfixturevalue(input_table)\n    result_df = c.sql(\n        f\"\"\"\n        SELECT\n            CAST(timezone AS DATE) AS timezone,\n            CAST(no_timezone AS DATE) AS no_timezone,\n            CAST(utc_timezone AS DATE) AS utc_timezone\n        FROM {input_table}\n        \"\"\"\n    )\n\n    expected_df = datetime_table\n    expected_df[\"timezone\"] = (\n        expected_df[\"timezone\"].dt.tz_localize(None).dt.floor(\"D\").astype(\"<M8[ns]\")\n    )\n    expected_df[\"no_timezone\"] = (\n        expected_df[\"no_timezone\"].astype(\"<M8[ns]\").dt.floor(\"D\").astype(\"<M8[ns]\")\n    )\n    expected_df[\"utc_timezone\"] = (\n        expected_df[\"utc_timezone\"].dt.tz_localize(None).dt.floor(\"D\").astype(\"<M8[ns]\")\n    )\n\n    assert_eq(result_df, expected_df)\n\n\n@pytest.mark.parametrize(\n    \"input_table\",\n    [\n        \"datetime_table\",\n        pytest.param(\"gpu_datetime_table\", marks=pytest.mark.gpu),\n    ],\n)\ndef test_timestamp_casting(c, input_table, request):\n    datetime_table = request.getfixturevalue(input_table)\n    result_df = c.sql(\n        f\"\"\"\n        SELECT\n            CAST(timezone AS TIMESTAMP) AS timezone,\n            CAST(no_timezone AS TIMESTAMP) AS no_timezone,\n            CAST(utc_timezone AS TIMESTAMP) AS utc_timezone\n        FROM {input_table}\n        \"\"\"\n    )\n\n    expected_df = datetime_table\n    expected_df[\"timezone\"] = expected_df[\"timezone\"].dt.tz_localize(None)\n    expected_df[\"utc_timezone\"] = expected_df[\"utc_timezone\"].dt.tz_localize(None)\n\n    assert_eq(result_df, expected_df)\n\n\ndef test_multi_case_when(c):\n    df = pd.DataFrame({\"a\": [1, 6, 7, 8, 9]})\n    c.create_table(\"df\", df)\n\n    actual_df = c.sql(\n        \"\"\"\n    SELECT\n        CASE WHEN a BETWEEN 6 AND 8 THEN 1 ELSE 0 END AS \"C\"\n    FROM df\n    \"\"\"\n    )\n    expected_df = 
pd.DataFrame({\"C\": [0, 1, 1, 1, 0]})\n\n    # dtype varies between int32/int64 depending on pandas version\n    assert_eq(actual_df, expected_df, check_dtype=False)\n\n\ndef test_case_when_no_else(c):\n    df = pd.DataFrame({\"a\": [1, 6, 7, 8, 9]})\n    c.create_table(\"df\", df)\n\n    actual_df = c.sql(\n        \"\"\"\n    SELECT\n        CASE WHEN a BETWEEN 6 AND 8 THEN 1 END AS \"C\"\n    FROM df\n    \"\"\"\n    )\n    expected_df = pd.DataFrame({\"C\": [None, 1, 1, 1, None]})\n\n    # dtype varies between float64/object depending on pandas version\n    assert_eq(actual_df, expected_df, check_dtype=False)\n\n\ndef test_singular_column_selection(c):\n    df = pd.DataFrame({\"a\": [1, 2, 3], \"b\": [4, 5, 6]})\n    c.create_table(\"df\", df)\n\n    wildcard_result = c.sql(\"SELECT * from df\")\n    single_col_result = c.sql(\"SELECT b from df\")\n\n    assert_eq(wildcard_result[\"b\"], single_col_result[\"b\"])\n\n\n@pytest.mark.parametrize(\n    \"input_cols\",\n    [\n        [\"a\"],\n        [\"a\", \"b\"],\n        [\"a\", \"d\"],\n        [\"d\", \"a\"],\n        [\"a\", \"b\", \"d\"],\n    ],\n)\n@skipif_dask_expr_enabled()\ndef test_multiple_column_projection(c, parquet_ddf, input_cols):\n    projection_list = \", \".join(input_cols)\n    result_df = c.sql(f\"SELECT {projection_list} from parquet_ddf\")\n\n    # There are 5 columns in the table, ensure only specified ones are read\n    assert_eq(len(result_df.columns), len(input_cols))\n    assert_eq(parquet_ddf[input_cols], result_df)\n    assert sorted(\n        hlg_layer(\n            optimize_dataframe_getitem(result_df.dask, result_df.__dask_keys__()),\n            \"read-parquet\",\n        ).columns\n    ) == sorted(input_cols)\n\n\ndef test_wildcard_select(c):\n    result_df = c.sql(\"SELECT COUNT(*) FROM df\")\n\n    expected_df = pd.DataFrame(\n        {\n            \"COUNT(*)\": [700],\n        }\n    )\n\n    assert_eq(result_df, expected_df)\n"
  },
  {
    "path": "tests/integration/test_server.py",
    "content": "from time import sleep\n\nimport pandas as pd\nimport pytest\n\nfrom dask_sql import Context\nfrom dask_sql.server.app import _init_app, app\nfrom tests.integration.fixtures import DISTRIBUTED_TESTS\n\n# needed for the test client\npytest.importorskip(\"requests\")\n\n\n@pytest.fixture(scope=\"module\")\ndef app_client():\n    c = Context()\n    c.sql(\"SELECT 1 + 1\").compute()\n    _init_app(app, c)\n\n    # late import, after the importorskip above\n    from fastapi.testclient import TestClient\n\n    yield TestClient(app)\n\n    # avoid closing the client if it's session-wide\n    if not DISTRIBUTED_TESTS:\n        app.client.close()\n\n\ndef get_result_or_error(app_client, response):\n    result = response.json()\n\n    assert \"nextUri\" in result\n    assert \"error\" not in result\n\n    status_url = result[\"nextUri\"]\n    next_url = status_url\n\n    counter = 0\n    while True:\n        response = app_client.get(next_url)\n        assert response.status_code == 200\n\n        result = response.json()\n\n        if \"nextUri\" not in result:\n            break\n\n        next_url = result[\"nextUri\"]\n\n        counter += 1\n        assert counter <= 100\n\n        sleep(0.1)\n\n    return result\n\n\ndef test_routes(app_client):\n    assert app_client.post(\"/v1/statement\", data=\"SELECT 1 + 1\").status_code == 200\n    assert app_client.get(\"/v1/statement\").status_code == 405\n    assert app_client.get(\"/v1/empty\").status_code == 200\n    assert app_client.get(\"/v1/status/some-wrong-uuid\").status_code == 404\n    assert app_client.delete(\"/v1/cancel/some-wrong-uuid\").status_code == 404\n    assert app_client.get(\"/v1/cancel/some-wrong-uuid\").status_code == 405\n\n\ndef test_sql_query_cancel(app_client):\n    response = app_client.post(\"/v1/statement\", data=\"SELECT 1 + 1\")\n    assert response.status_code == 200\n\n    cancel_url = response.json()[\"partialCancelUri\"]\n\n    response = app_client.delete(cancel_url)\n    assert 
response.status_code == 200\n\n    response = app_client.delete(cancel_url)\n    assert response.status_code == 404\n\n\ndef test_sql_query(app_client):\n    response = app_client.post(\"/v1/statement\", data=\"SELECT 1 + 1\")\n    assert response.status_code == 200\n\n    result = get_result_or_error(app_client, response)\n\n    assert \"columns\" in result\n    assert \"data\" in result\n    assert \"error\" not in result\n    assert \"nextUri\" not in result\n\n    assert result[\"columns\"] == [\n        {\n            \"name\": \"Int64(1) + Int64(1)\",\n            \"type\": \"bigint\",\n            \"typeSignature\": {\"rawType\": \"bigint\", \"arguments\": []},\n        }\n    ]\n    assert result[\"data\"] == [[2]]\n\n\ndef test_wrong_sql_query(app_client):\n    response = app_client.post(\"/v1/statement\", data=\"SELECT 1 + \")\n    assert response.status_code == 200\n\n    result = response.json()\n\n    assert \"columns\" not in result\n    assert \"data\" not in result\n    assert \"error\" in result\n    assert \"message\" in result[\"error\"]\n    # FIXME: ParserErrors currently don't contain information on where the syntax error occurred\n    # assert \"errorLocation\" in result[\"error\"]\n    # assert result[\"error\"][\"errorLocation\"] == {\n    #     \"lineNumber\": 1,\n    #     \"columnNumber\": 10,\n    # }\n\n\ndef test_add_and_query(app_client, df, temporary_data_file):\n    df.to_csv(temporary_data_file, index=False)\n\n    response = app_client.post(\n        \"/v1/statement\",\n        data=f\"\"\"\n        CREATE TABLE\n            new_table\n        WITH (\n            location = '{temporary_data_file}',\n            format = 'csv'\n        )\n    \"\"\",\n    )\n    result = response.json()\n    assert \"error\" not in result\n    assert response.status_code == 200\n\n    response = app_client.post(\"/v1/statement\", data=\"SELECT * FROM new_table\")\n    assert response.status_code == 200\n\n    result = 
get_result_or_error(app_client, response)\n\n    assert \"columns\" in result\n    assert \"data\" in result\n    assert result[\"columns\"] == [\n        {\n            \"name\": \"a\",\n            \"type\": \"double\",\n            \"typeSignature\": {\"rawType\": \"double\", \"arguments\": []},\n        },\n        {\n            \"name\": \"b\",\n            \"type\": \"double\",\n            \"typeSignature\": {\"rawType\": \"double\", \"arguments\": []},\n        },\n    ]\n\n    assert len(result[\"data\"]) == 700\n    assert \"error\" not in result\n\n\ndef test_register_and_query(app_client, df):\n    df[\"a\"] = df[\"a\"].astype(\"UInt8\")\n    app_client.app.c.create_table(\"new_table\", df)\n\n    response = app_client.post(\"/v1/statement\", data=\"SELECT * FROM new_table\")\n    assert response.status_code == 200\n\n    result = get_result_or_error(app_client, response)\n\n    assert \"columns\" in result\n    assert \"data\" in result\n    assert result[\"columns\"] == [\n        {\n            \"name\": \"a\",\n            \"type\": \"tinyint\",\n            \"typeSignature\": {\"rawType\": \"tinyint\", \"arguments\": []},\n        },\n        {\n            \"name\": \"b\",\n            \"type\": \"double\",\n            \"typeSignature\": {\"rawType\": \"double\", \"arguments\": []},\n        },\n    ]\n\n    assert len(result[\"data\"]) == 700\n    assert \"error\" not in result\n\n\ndef test_inf_table(app_client, user_table_inf):\n    app_client.app.c.create_table(\"new_table\", user_table_inf)\n\n    response = app_client.post(\"/v1/statement\", data=\"SELECT * FROM new_table\")\n    assert response.status_code == 200\n\n    result = get_result_or_error(app_client, response)\n\n    assert \"columns\" in result\n    assert \"data\" in result\n    assert result[\"columns\"] == [\n        {\n            \"name\": \"c\",\n            \"type\": \"double\",\n            \"typeSignature\": {\"rawType\": \"double\", \"arguments\": []},\n        }\n    
]\n\n    assert len(result[\"data\"]) == 3\n    assert result[\"data\"][1] == [\"+Infinity\"]\n    assert \"error\" not in result\n\n\ndef test_nullable_int_table(app_client):\n    app_client.app.c.create_table(\n        \"null_table\", pd.DataFrame({\"a\": [None]}, dtype=\"Int64\")\n    )\n\n    response = app_client.post(\"/v1/statement\", data=\"SELECT * FROM null_table\")\n    assert response.status_code == 200\n\n    result = get_result_or_error(app_client, response)\n\n    assert \"columns\" in result\n    assert \"data\" in result\n    assert result[\"columns\"] == [\n        {\n            \"name\": \"a\",\n            \"type\": \"bigint\",\n            \"typeSignature\": {\"rawType\": \"bigint\", \"arguments\": []},\n        }\n    ]\n\n    assert len(result[\"data\"]) == 1\n    assert result[\"data\"][0] == [None]\n    assert \"error\" not in result\n"
  },
  {
    "path": "tests/integration/test_show.py",
    "content": "import pandas as pd\nimport pytest\n\nfrom dask_sql import Context\nfrom dask_sql.utils import ParsingException\nfrom tests.utils import assert_eq\n\n\ndef test_schemas(c):\n    expected_df = pd.DataFrame({\"Schema\": [c.schema_name, \"information_schema\"]})\n\n    assert_eq(c.sql(\"SHOW SCHEMAS\"), expected_df)\n    assert_eq(c.sql(f\"SHOW SCHEMAS FROM {c.catalog_name}\"), expected_df)\n\n    expected_df = pd.DataFrame({\"Schema\": [\"information_schema\"]})\n\n    assert_eq(\n        c.sql(\"SHOW SCHEMAS LIKE 'information_schema'\"), expected_df, check_index=False\n    )\n    assert_eq(\n        c.sql(f\"SHOW SCHEMAS FROM {c.catalog_name} LIKE 'information_schema'\"),\n        expected_df,\n        check_index=False,\n    )\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_tables(gpu):\n    c = Context()\n    c.create_table(\"table\", pd.DataFrame(), gpu=gpu)\n\n    expected_df = pd.DataFrame({\"Table\": [\"table\"]})\n\n    assert_eq(\n        c.sql(f'SHOW TABLES FROM \"{c.schema_name}\"'), expected_df, check_index=False\n    )\n    assert_eq(\n        c.sql(f'SHOW TABLES FROM \"{c.catalog_name}\".\"{c.schema_name}\"'),\n        expected_df,\n        check_index=False,\n    )\n\n\ndef test_columns(c):\n    result_df = c.sql(f'SHOW COLUMNS FROM \"{c.schema_name}\".\"user_table_1\"')\n    expected_df = pd.DataFrame(\n        {\n            \"Column\": [\n                \"user_id\",\n                \"b\",\n            ],\n            \"Type\": [\"bigint\", \"bigint\"],\n            \"Extra\": [\"\"] * 2,\n            \"Comment\": [\"\"] * 2,\n        }\n    )\n\n    assert_eq(result_df, expected_df)\n\n    result_df = c.sql('SHOW COLUMNS FROM \"user_table_1\"')\n\n    assert_eq(result_df, expected_df)\n\n\ndef test_wrong_input(c):\n    with pytest.raises(KeyError):\n        c.sql('SHOW COLUMNS FROM \"wrong\".\"table\"')\n    with pytest.raises(ParsingException):\n        c.sql('SHOW COLUMNS FROM 
\"wrong\".\"table\".\"column\"')\n    with pytest.raises(KeyError):\n        c.sql(f'SHOW COLUMNS FROM \"{c.schema_name}\".\"table\"')\n    with pytest.raises(AttributeError):\n        c.sql('SHOW TABLES FROM \"wrong\"')\n    with pytest.raises(RuntimeError):\n        c.sql(f'SHOW TABLES FROM \"wrong\".\"{c.schema_name}\"')\n    with pytest.raises(RuntimeError):\n        c.sql('SHOW SCHEMAS FROM \"wrong\"')\n\n\ndef test_show_tables(c):\n    c = Context()\n\n    df = pd.DataFrame({\"id\": [0, 1]})\n    c.create_table(\"test\", df)\n\n    expected_df = pd.DataFrame({\"Table\": [\"test\"]})\n\n    # no schema specified\n    assert_eq(c.sql(\"show tables\"), expected_df)\n\n    # unqualified schema\n    assert_eq(c.sql(\"show tables from root\"), expected_df)\n\n    # qualified schema\n    assert_eq(c.sql(\"show tables from dask_sql.root\"), expected_df)\n"
  },
  {
    "path": "tests/integration/test_sort.py",
    "content": "import dask.dataframe as dd\nimport pandas as pd\nimport pytest\n\nfrom dask_sql.context import Context\nfrom tests.utils import assert_eq\n\n\n@pytest.mark.parametrize(\n    \"input_table_1,input_df\",\n    [\n        (\"user_table_1\", \"df\"),\n        pytest.param(\"gpu_user_table_1\", \"gpu_df\", marks=pytest.mark.gpu),\n    ],\n)\ndef test_sort(c, input_table_1, input_df, request):\n    user_table_1 = request.getfixturevalue(input_table_1)\n    df = request.getfixturevalue(input_df)\n\n    df_result = c.sql(\n        f\"\"\"\n    SELECT\n        *\n    FROM {input_table_1}\n    ORDER BY b, user_id DESC\n    \"\"\"\n    )\n    df_expected = user_table_1.sort_values([\"b\", \"user_id\"], ascending=[True, False])\n\n    assert_eq(df_result, df_expected, check_index=False)\n\n    df_result = c.sql(\n        f\"\"\"\n    SELECT\n        *\n    FROM {input_df}\n    ORDER BY b DESC, a DESC\n    \"\"\"\n    )\n    df_expected = df.sort_values([\"b\", \"a\"], ascending=[False, False])\n\n    assert_eq(df_result, df_expected, check_index=False)\n\n    df_result = c.sql(\n        f\"\"\"\n    SELECT\n        *\n    FROM {input_df}\n    ORDER BY a DESC, b\n    \"\"\"\n    )\n    df_expected = df.sort_values([\"a\", \"b\"], ascending=[False, True])\n\n    assert_eq(df_result, df_expected, check_index=False)\n\n    df_result = c.sql(\n        f\"\"\"\n    SELECT\n        *\n    FROM {input_df}\n    ORDER BY b, a\n    \"\"\"\n    )\n    df_expected = df.sort_values([\"b\", \"a\"], ascending=[True, True])\n\n    assert_eq(df_result, df_expected, check_index=False)\n\n\n@pytest.mark.parametrize(\n    \"input_table_1\",\n    [\"user_table_1\", pytest.param(\"gpu_user_table_1\", marks=pytest.mark.gpu)],\n)\ndef test_sort_by_alias(c, input_table_1, request):\n    user_table_1 = request.getfixturevalue(input_table_1)\n\n    df_result = c.sql(\n        f\"\"\"\n    SELECT\n        b AS my_column\n    FROM {input_table_1}\n    ORDER BY my_column, user_id DESC\n    
\"\"\"\n    ).rename(columns={\"my_column\": \"b\"})\n    df_expected = user_table_1.sort_values([\"b\", \"user_id\"], ascending=[True, False])[\n        [\"b\"]\n    ]\n\n    assert_eq(df_result, df_expected, check_index=False)\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_sort_with_nan(gpu):\n    c = Context()\n    df = pd.DataFrame(\n        {\"a\": [1, 2, float(\"nan\"), 2], \"b\": [4, float(\"nan\"), 5, float(\"inf\")]}\n    )\n    c.create_table(\"df\", df, gpu=gpu)\n\n    df_result = c.sql(\"SELECT * FROM df ORDER BY a\")\n    assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\"a\": [1, 2, 2, float(\"nan\")], \"b\": [4, float(\"nan\"), float(\"inf\"), 5]}\n        ),\n        check_index=False,\n    )\n\n    df_result = c.sql(\"SELECT * FROM df ORDER BY a NULLS FIRST\")\n    assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\"a\": [float(\"nan\"), 1, 2, 2], \"b\": [5, 4, float(\"nan\"), float(\"inf\")]}\n        ),\n        check_index=False,\n    )\n\n    df_result = c.sql(\"SELECT * FROM df ORDER BY a NULLS LAST\")\n    assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\"a\": [1, 2, 2, float(\"nan\")], \"b\": [4, float(\"nan\"), float(\"inf\"), 5]}\n        ),\n        check_index=False,\n    )\n\n    df_result = c.sql(\"SELECT * FROM df ORDER BY a ASC\")\n    assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\"a\": [1, 2, 2, float(\"nan\")], \"b\": [4, float(\"nan\"), float(\"inf\"), 5]}\n        ),\n        check_index=False,\n    )\n\n    df_result = c.sql(\"SELECT * FROM df ORDER BY a ASC NULLS FIRST\")\n    assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\"a\": [float(\"nan\"), 1, 2, 2], \"b\": [5, 4, float(\"nan\"), float(\"inf\")]}\n        ),\n        check_index=False,\n    )\n\n    df_result = c.sql(\"SELECT * FROM df ORDER BY a ASC NULLS LAST\")\n    assert_eq(\n        df_result,\n        
pd.DataFrame(\n            {\"a\": [1, 2, 2, float(\"nan\")], \"b\": [4, float(\"nan\"), float(\"inf\"), 5]}\n        ),\n        check_index=False,\n    )\n\n    df_result = c.sql(\"SELECT * FROM df ORDER BY a DESC\")\n    assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\"a\": [float(\"nan\"), 2, 2, 1], \"b\": [5, float(\"nan\"), float(\"inf\"), 4]}\n        ),\n        check_index=False,\n    )\n\n    df_result = c.sql(\"SELECT * FROM df ORDER BY a DESC NULLS FIRST\")\n    assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\"a\": [float(\"nan\"), 2, 2, 1], \"b\": [5, float(\"nan\"), float(\"inf\"), 4]}\n        ),\n        check_index=False,\n    )\n\n    df_result = c.sql(\"SELECT * FROM df ORDER BY a DESC NULLS LAST\")\n    assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\"a\": [2, 2, 1, float(\"nan\")], \"b\": [float(\"nan\"), float(\"inf\"), 4, 5]}\n        ),\n        check_index=False,\n    )\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_sort_with_nan_more_columns(gpu):\n    c = Context()\n    df = pd.DataFrame(\n        {\n            \"a\": [1, 1, 2, 2, float(\"nan\"), float(\"nan\")],\n            \"b\": [1, 1, 2, float(\"nan\"), float(\"inf\"), 5],\n            \"c\": [1, float(\"nan\"), 3, 4, 5, 6],\n        }\n    )\n    c.create_table(\"df\", df, gpu=gpu)\n\n    df_result = c.sql(\n        \"SELECT * FROM df ORDER BY a ASC NULLS FIRST, b DESC NULLS LAST, c ASC NULLS FIRST\"\n    )\n    assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\n                \"a\": [float(\"nan\"), float(\"nan\"), 1, 1, 2, 2],\n                \"b\": [float(\"inf\"), 5, 1, 1, 2, float(\"nan\")],\n                \"c\": [5, 6, float(\"nan\"), 1, 3, 4],\n            }\n        ),\n        check_index=False,\n    )\n\n    df_result = c.sql(\n        \"SELECT * FROM df ORDER BY a ASC NULLS LAST, b DESC NULLS FIRST, c DESC NULLS LAST\"\n    )\n    
assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\n                \"a\": [1, 1, 2, 2, float(\"nan\"), float(\"nan\")],\n                \"b\": [1, 1, float(\"nan\"), 2, float(\"inf\"), 5],\n                \"c\": [1, float(\"nan\"), 4, 3, 5, 6],\n            }\n        ),\n        check_index=False,\n    )\n\n    df_result = c.sql(\n        \"SELECT * FROM df ORDER BY a ASC NULLS FIRST, b DESC NULLS LAST, c DESC NULLS LAST\"\n    )\n    assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\n                \"a\": [float(\"nan\"), float(\"nan\"), 1, 1, 2, 2],\n                \"b\": [float(\"inf\"), 5, 1, 1, 2, float(\"nan\")],\n                \"c\": [5, 6, 1, float(\"nan\"), 3, 4],\n            }\n        ),\n        check_index=False,\n    )\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_sort_with_nan_many_partitions(gpu):\n    c = Context()\n    df = pd.DataFrame(\n        {\n            \"a\": [float(\"nan\"), 1] * 30,\n            \"b\": [1, 2, 3] * 20,\n        }\n    )\n    c.create_table(\"df\", dd.from_pandas(df, npartitions=10), gpu=gpu)\n\n    df_result = c.sql(\"SELECT * FROM df ORDER BY a NULLS FIRST, b ASC NULLS FIRST\")\n\n    assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\n                \"a\": [float(\"nan\")] * 30 + [1] * 30,\n                \"b\": [1] * 10 + [2] * 10 + [3] * 10 + [1] * 10 + [2] * 10 + [3] * 10,\n            }\n        ),\n        check_index=False,\n    )\n\n    df = pd.DataFrame({\"a\": [float(\"nan\"), 1] * 30})\n    c.create_table(\"df\", dd.from_pandas(df, npartitions=10))\n\n    df_result = c.sql(\"SELECT * FROM df ORDER BY a\")\n\n    assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\n                \"a\": [1] * 30 + [float(\"nan\")] * 30,\n            }\n        ),\n        check_index=False,\n    )\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, 
marks=pytest.mark.gpu)])\ndef test_sort_strings(c, gpu):\n    string_table = pd.DataFrame({\"a\": [\"zzhsd\", \"öfjdf\", \"baba\"]})\n    c.create_table(\"string_table\", string_table, gpu=gpu)\n\n    df_result = c.sql(\n        \"\"\"\n    SELECT\n        *\n    FROM string_table\n    ORDER BY a\n    \"\"\"\n    )\n\n    df_expected = string_table.sort_values([\"a\"], ascending=True)\n\n    assert_eq(df_result, df_expected, check_index=False)\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_sort_not_allowed(c, gpu):\n    table_name = \"gpu_user_table_1\" if gpu else \"user_table_1\"\n\n    # Wrong column\n    with pytest.raises(Exception):\n        c.sql(f\"SELECT * FROM {table_name} ORDER BY 42\")\n\n\n@pytest.mark.parametrize(\n    \"input_table_1\",\n    [\"user_table_1\", pytest.param(\"gpu_user_table_1\", marks=pytest.mark.gpu)],\n)\ndef test_sort_by_old_alias(c, input_table_1, request):\n    user_table_1 = request.getfixturevalue(input_table_1)\n\n    df_result = c.sql(\n        f\"\"\"\n    SELECT\n        b AS my_column\n    FROM {input_table_1}\n    ORDER BY b, user_id DESC\n    \"\"\"\n    ).rename(columns={\"my_column\": \"b\"})\n    df_expected = user_table_1.sort_values([\"b\", \"user_id\"], ascending=[True, False])[\n        [\"b\"]\n    ]\n\n    assert_eq(df_result, df_expected, check_index=False)\n\n    df_result = c.sql(\n        f\"\"\"\n    SELECT\n        b*-1 AS my_column\n    FROM {input_table_1}\n    ORDER BY b, user_id DESC\n    \"\"\"\n    ).rename(columns={\"my_column\": \"b\"})\n    df_expected = user_table_1.sort_values([\"b\", \"user_id\"], ascending=[True, False])[\n        [\"b\"]\n    ]\n    df_expected[\"b\"] *= -1\n    assert_eq(df_result, df_expected, check_index=False)\n\n    df_result = c.sql(\n        f\"\"\"\n    SELECT\n        b*-1 AS my_column\n    FROM {input_table_1}\n    ORDER BY my_column, user_id DESC\n    \"\"\"\n    ).rename(columns={\"my_column\": \"b\"})\n    
df_expected[\"b\"] *= -1\n    df_expected = user_table_1.sort_values([\"b\", \"user_id\"], ascending=[True, False])[\n        [\"b\"]\n    ]\n\n\ndef check_sort_topk(df, layer, contains=True):\n    if dd._dask_expr_enabled():\n        from dask_expr._reductions import NLargest, NSmallest\n\n        if layer == \"nsmallest\":\n            assert len(list(df.expr.find_operations(NSmallest))) == (\n                1 if contains else 0\n            )\n        elif layer == \"nlargest\":\n            assert len(list(df.expr.find_operations(NLargest))) == (\n                1 if contains else 0\n            )\n        else:\n            assert False\n    else:\n        assert (\n            any([layer in key for key in df.dask.layers.keys()])\n            if contains\n            else all([layer not in key for key in df.dask.layers.keys()])\n        )\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_sort_topk(gpu):\n    c = Context()\n    df = pd.DataFrame(\n        {\n            \"a\": [float(\"nan\"), 1] * 30,\n            \"b\": [1, 2, 3] * 20,\n            \"c\": [\"a\", \"b\", \"c\"] * 20,\n        }\n    )\n    c.create_table(\"df\", dd.from_pandas(df, npartitions=10), gpu=gpu)\n\n    df_result = c.sql(\"\"\"SELECT * FROM df ORDER BY a LIMIT 10\"\"\")\n    check_sort_topk(df_result, \"nsmallest\", True)\n    assert_eq(\n        df_result,\n        pd.DataFrame(\n            {\n                \"a\": [1.0] * 10,\n                \"b\": ([2, 1, 3] * 4)[:10],\n                \"c\": ([\"b\", \"a\", \"c\"] * 4)[:10],\n            }\n        ),\n        check_index=False,\n    )\n\n    df_result = c.sql(\"\"\"SELECT * FROM df ORDER BY a, b LIMIT 10\"\"\")\n    check_sort_topk(df_result, \"nsmallest\", True)\n    assert_eq(\n        df_result,\n        pd.DataFrame({\"a\": [1.0] * 10, \"b\": [1] * 10, \"c\": [\"a\"] * 10}),\n        check_index=False,\n    )\n\n    df_result = c.sql(\n        \"\"\"SELECT * FROM df 
ORDER BY a DESC NULLS LAST, b DESC NULLS LAST LIMIT 10\"\"\"\n    )\n    check_sort_topk(df_result, \"nlargest\", True)\n    assert_eq(\n        df_result,\n        pd.DataFrame({\"a\": [1.0] * 10, \"b\": [3] * 10, \"c\": [\"c\"] * 10}),\n        check_index=False,\n    )\n\n    # String column nlargest/nsmallest is not supported for pandas\n    df_result = c.sql(\"\"\"SELECT * FROM df ORDER BY c LIMIT 10\"\"\")\n    if not gpu:\n        check_sort_topk(df_result, \"nsmallest\", False)\n        check_sort_topk(df_result, \"nlargest\", False)\n    else:\n        assert_eq(\n            df_result,\n            pd.DataFrame({\"a\": [float(\"nan\"), 1] * 5, \"b\": [1] * 10, \"c\": [\"a\"] * 10}),\n            check_index=False,\n        )\n\n    # Assert that the optimization isn't applied when any sort key uses NULLS FIRST\n    df_result = c.sql(\n        \"\"\"SELECT * FROM df ORDER BY a DESC, b DESC NULLS LAST LIMIT 10\"\"\"\n    )\n    check_sort_topk(df_result, \"nlargest\", False)\n    check_sort_topk(df_result, \"nsmallest\", False)\n\n    # Assert optimization isn't applied for mixed asc + desc sort\n    df_result = c.sql(\"\"\"SELECT * FROM df ORDER BY a, b DESC NULLS LAST LIMIT 10\"\"\")\n    check_sort_topk(df_result, \"nlargest\", False)\n    check_sort_topk(df_result, \"nsmallest\", False)\n\n    # Assert optimization isn't applied when the number of requested elements\n    # exceeds the topk-nelem-limit config value\n    # Default topk-nelem-limit is 1M and 334k * 3 columns takes it above this limit\n    df_result = c.sql(\"\"\"SELECT * FROM df ORDER BY a, b LIMIT 333334\"\"\")\n    check_sort_topk(df_result, \"nlargest\", False)\n    check_sort_topk(df_result, \"nsmallest\", False)\n\n    df_result = c.sql(\n        \"\"\"SELECT * FROM df ORDER BY a, b LIMIT 10\"\"\",\n        config_options={\"sql.sort.topk-nelem-limit\": 29},\n    )\n    check_sort_topk(df_result, \"nlargest\", False)\n    check_sort_topk(df_result, \"nsmallest\", False)\n"
  },
  {
    "path": "tests/integration/test_sqlite.py",
    "content": "import sqlite3\n\nimport pytest\n\n\n@pytest.fixture(scope=\"session\")\ndef engine():\n    yield sqlite3.connect(\":memory:\")\n\n\ndef test_select(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT * FROM df1\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT df1.user_id + 5 AS user_id, 2 * df1.b AS b FROM df1\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT df2.user_id, df2.d FROM df2\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT 1 AS I, -5.34344 AS F, 'öäll' AS S\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT CASE WHEN user_id = 3 THEN 4 ELSE user_id END FROM df2\n    \"\"\"\n    )\n\n\ndef test_join(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            df1.user_id, df1.a, df1.b,\n            df2.user_id AS user_id_2, df2.c, df2.d\n        FROM df1\n        JOIN df2 ON df1.user_id = df2.user_id\n    \"\"\",\n        [\"user_id\", \"a\", \"b\", \"user_id_2\", \"c\", \"d\"],\n    )\n\n\ndef test_sort(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            user_id, b\n        FROM df1\n        ORDER BY b NULLS FIRST, user_id DESC NULLS FIRST\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            c, d\n        FROM df2\n        WHERE d IS NOT NULL -- sqlite sorts the NaNs in a strange way\n        ORDER BY c, d, user_id\n    \"\"\"\n    )\n\n\ndef test_limit(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            c, d\n        FROM df2\n        WHERE d IS NOT NULL -- sqlite sorts the NaNs in a strange way\n        ORDER BY c, d, user_id\n        LIMIT 10 OFFSET 20\n    \"\"\"\n    )\n\n    
assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            c, d\n        FROM df2\n        WHERE d IS NOT NULL -- sqlite sorts the NaNs in a strange way\n        ORDER BY c, d, user_id\n        LIMIT 200\n    \"\"\"\n    )\n\n\ndef test_groupby(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            d, SUM(c), SUM(user_id)\n        FROM df2\n        WHERE d IS NOT NULL -- dask behaves differently on NaNs in groupbys\n        GROUP BY d\n        ORDER BY SUM(c)\n        LIMIT 10\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT AVG(c)\n        FROM df2\n    \"\"\"\n    )\n\n\ndef test_calc(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            a + b,\n            a*b,\n            a*5,\n            a / user_id,\n            user_id / a\n        FROM df1\n    \"\"\"\n    )\n\n\ndef test_filter(assert_query_gives_same_result):\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            a\n        FROM df1\n        WHERE\n            user_id = 3 AND a > 0.5\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            d\n        FROM df2\n        WHERE\n            d NOT LIKE '%c'\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            d\n        FROM df2\n        WHERE\n            d = 'a'\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            *\n        FROM df1\n        WHERE\n            1 < a AND a < 5\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            *\n        FROM df1\n        WHERE\n            a < 5 AND b < 5\n    \"\"\"\n    )\n\n    assert_query_gives_same_result(\n        \"\"\"\n        SELECT\n            *\n        FROM df1\n        WHERE\n            a + b > 5\n    
\"\"\"\n    )\n"
  },
  {
    "path": "tests/integration/test_union.py",
    "content": "import pandas as pd\n\nfrom tests.utils import assert_eq\n\n\ndef test_union_not_all(c, df):\n    result_df = c.sql(\n        \"\"\"\n        SELECT * FROM df\n        UNION\n        SELECT * FROM df\n        UNION\n        SELECT * FROM df\n        \"\"\"\n    )\n\n    assert_eq(result_df, df, check_index=False)\n\n\ndef test_union_all(c, df):\n    result_df = c.sql(\n        \"\"\"\n        SELECT * FROM df\n        UNION ALL\n        SELECT * FROM df\n        UNION ALL\n        SELECT * FROM df\n        \"\"\"\n    )\n    expected_df = pd.concat([df, df, df], ignore_index=True)\n\n    assert_eq(result_df, expected_df, check_index=False)\n\n\ndef test_union_mixed(c, df, long_table):\n    result_df = c.sql(\n        \"\"\"\n        SELECT a AS \"I\", b as \"II\" FROM df\n        UNION ALL\n        SELECT a as \"I\", a as \"II\" FROM long_table\n        \"\"\"\n    )\n    long_table = long_table.rename(columns={\"a\": \"I\"})\n    long_table[\"II\"] = long_table[\"I\"]\n    expected_df = pd.concat(\n        [df.rename(columns={\"a\": \"I\", \"b\": \"II\"}), long_table],\n        ignore_index=True,\n    )\n\n    assert_eq(result_df, expected_df, check_index=False)\n"
  },
  {
    "path": "tests/unit/__init__.py",
    "content": ""
  },
  {
    "path": "tests/unit/test_call.py",
    "content": "import datetime\nimport operator\nfrom unittest.mock import MagicMock\n\nimport dask.dataframe as dd\nimport numpy as np\nimport pandas as pd\n\nimport dask_sql.physical.rex.core.call as call\nfrom tests.utils import assert_eq\n\ndf1 = dd.from_pandas(pd.DataFrame({\"a\": [1, 2, 3]}), npartitions=1)\ndf2 = dd.from_pandas(pd.DataFrame({\"a\": [3, 2, 1]}), npartitions=1)\ndf3 = dd.from_pandas(\n    pd.DataFrame({\"a\": [True, pd.NA, False]}, dtype=\"boolean\"), npartitions=1\n)\nops_mapping = call.RexCallPlugin.OPERATION_MAPPING\n\n\ndef test_operation():\n    operator = MagicMock()\n    operator.return_value = \"test\"\n\n    op = call.Operation(operator)\n\n    assert op(\"input\") == \"test\"\n    operator.assert_called_once_with(\"input\")\n\n\ndef test_reduce():\n    op = call.ReduceOperation(operator.add)\n\n    assert op(1, 2, 3) == 6\n\n\ndef test_case():\n    op = call.CaseOperation()\n\n    assert_eq(op(df1.a > 2, df1.a, df2.a), pd.Series([3, 2, 3]), check_names=False)\n\n    assert_eq(op(df1.a > 2, 99, df2.a), pd.Series([3, 2, 99]), check_names=False)\n\n    assert_eq(op(df1.a > 2, 99, -1), pd.Series([-1, -1, 99]), check_names=False)\n\n    assert_eq(op(df1.a > 2, df1.a, -1), pd.Series([-1, -1, 3]), check_names=False)\n\n    assert op(True, 1, 2) == 1\n    assert op(False, 1, 2) == 2\n\n\ndef test_is_true():\n    op = call.IsTrueOperation()\n\n    assert_eq(\n        op(df1.a > 2),\n        pd.Series([False, False, True]),\n        check_names=False,\n        check_dtype=False,\n    )\n    assert_eq(\n        op(df3.a),\n        pd.Series([True, False, False]),\n        check_names=False,\n        check_dtype=False,\n    )\n\n    assert op(1)\n    assert not op(0)\n    assert not op(None)\n    assert not op(np.NaN)\n    assert not op(pd.NA)\n\n\ndef test_is_false():\n    op = call.IsFalseOperation()\n\n    assert_eq(\n        op(df1.a > 2),\n        pd.Series([True, True, False]),\n        check_names=False,\n        check_dtype=False,\n    
)\n    assert_eq(\n        op(df3.a),\n        pd.Series([False, False, True]),\n        check_names=False,\n        check_dtype=False,\n    )\n\n    assert not op(1)\n    assert op(0)\n    assert not op(None)\n    assert not op(np.NaN)\n    assert not op(pd.NA)\n\n\ndef test_like():\n    op = call.LikeOperation()\n\n    assert op(\"a string\", r\"%a%\")\n    assert op(\"another string\", r\"a%\")\n    assert not op(\"another string\", r\"s%\")\n\n    op = call.SimilarOperation()\n    assert op(\"normal\", r\"n[a-z]rm_l\")\n    assert not op(\"not normal\", r\"n[a-z]rm_l\")\n\n\ndef test_not():\n    op = call.NotOperation()\n\n    assert op(False)\n    assert not op(True)\n\n    assert not op(3)\n\n\ndef test_nan():\n    op = call.IsNullOperation()\n\n    assert op(None)\n    assert op(np.NaN)\n    assert op(pd.NA)\n    assert_eq(op(pd.Series([\"a\", None, \"c\"])), pd.Series([False, True, False]))\n    assert_eq(\n        op(pd.Series([3, 2, np.NaN, pd.NA])), pd.Series([False, False, True, True])\n    )\n\n\ndef test_simple_ops():\n    assert_eq(\n        ops_mapping[\"and\"](df1.a >= 2, df2.a >= 2),\n        pd.Series([False, True, False]),\n        check_names=False,\n    )\n\n    assert_eq(\n        ops_mapping[\"or\"](df1.a >= 2, df2.a >= 2),\n        pd.Series([True, True, True]),\n        check_names=False,\n    )\n\n    assert_eq(\n        ops_mapping[\">=\"](df1.a, df2.a),\n        pd.Series([False, True, True]),\n        check_names=False,\n    )\n\n    assert_eq(\n        ops_mapping[\"+\"](df1.a, df2.a, df1.a),\n        pd.Series([5, 6, 7]),\n        check_names=False,\n    )\n\n\ndef test_math_operations():\n    assert_eq(\n        ops_mapping[\"abs\"](-df1.a),\n        pd.Series([1, 2, 3]),\n        check_names=False,\n    )\n    assert_eq(\n        ops_mapping[\"round\"](df1.a),\n        pd.Series([1, 2, 3]),\n        check_names=False,\n    )\n    assert_eq(\n        ops_mapping[\"floor\"](df1.a),\n        pd.Series([1.0, 2.0, 3.0]),\n        
check_names=False,\n    )\n\n    assert ops_mapping[\"abs\"](-5) == 5\n    assert ops_mapping[\"round\"](1.234, 2) == 1.23\n    assert ops_mapping[\"floor\"](1.234) == 1\n\n\ndef test_string_operations():\n    a = \"a normal string\"\n    assert ops_mapping[\"characterlength\"](a) == 15\n    assert ops_mapping[\"upper\"](a) == \"A NORMAL STRING\"\n    assert ops_mapping[\"lower\"](a) == \"a normal string\"\n    assert ops_mapping[\"position\"](\"a\", a, 4) == 7\n    assert ops_mapping[\"position\"](\"ZL\", a) == 0\n    assert ops_mapping[\"trim\"](a, \"a\") == \" normal string\"\n    assert ops_mapping[\"btrim\"](a, \"a\") == \" normal string\"\n    assert ops_mapping[\"ltrim\"](a, \"a\") == \" normal string\"\n    assert ops_mapping[\"rtrim\"](a, \"a\") == \"a normal string\"\n    assert ops_mapping[\"overlay\"](a, \"XXX\", 2) == \"aXXXrmal string\"\n    assert ops_mapping[\"overlay\"](a, \"XXX\", 2, 4) == \"aXXXmal string\"\n    assert ops_mapping[\"overlay\"](a, \"XXX\", 2, 1) == \"aXXXnormal string\"\n    assert ops_mapping[\"substring\"](a, -1) == \"a normal string\"\n    assert ops_mapping[\"substring\"](a, 10) == \"string\"\n    assert ops_mapping[\"substring\"](a, 2) == \" normal string\"\n    assert ops_mapping[\"substring\"](a, 2, 2) == \" n\"\n    assert ops_mapping[\"initcap\"](a) == \"A Normal String\"\n    assert ops_mapping[\"replace\"](a, \"nor\", \"\") == \"a mal string\"\n    assert ops_mapping[\"replace\"](a, \"normal\", \"new\") == \"a new string\"\n    assert ops_mapping[\"replace\"](\"hello\", \"\", \"w\") == \"whwewlwlwow\"\n\n\ndef test_dates():\n    op = call.ExtractOperation()\n\n    date = datetime.datetime(2021, 10, 3, 15, 53, 42, 47)\n    assert int(op(\"CENTURY\", date)) == 20\n    assert op(\"DAY\", date) == 3\n    assert int(op(\"DECADE\", date)) == 202\n    assert op(\"DOW\", date) == 0\n    assert op(\"DOY\", date) == 276\n    assert op(\"HOUR\", date) == 15\n    assert op(\"MICROSECOND\", date) == 47\n    assert op(\"MILLENNIUM\", 
date) == 2\n    assert op(\"MILLISECOND\", date) == 47000\n    assert op(\"MINUTE\", date) == 53\n    assert op(\"MONTH\", date) == 10\n    assert op(\"QUARTER\", date) == 4\n    assert op(\"SECOND\", date) == 42\n    assert op(\"WEEK\", date) == 39\n    assert op(\"YEAR\", date) == 2021\n    assert op(\"DATE\", date) == datetime.date(2021, 10, 3)\n\n    ceil_op = call.CeilFloorOperation(\"ceil\")\n    floor_op = call.CeilFloorOperation(\"floor\")\n\n    assert ceil_op(date, \"DAY\") == datetime.datetime(2021, 10, 4)\n    assert ceil_op(date, \"HOUR\") == datetime.datetime(2021, 10, 3, 16)\n    assert ceil_op(date, \"MINUTE\") == datetime.datetime(2021, 10, 3, 15, 54)\n    assert ceil_op(date, \"SECOND\") == datetime.datetime(2021, 10, 3, 15, 53, 43)\n    assert ceil_op(date, \"MILLISECOND\") == datetime.datetime(\n        2021, 10, 3, 15, 53, 42, 1000\n    )\n\n    assert floor_op(date, \"DAY\") == datetime.datetime(2021, 10, 3)\n    assert floor_op(date, \"HOUR\") == datetime.datetime(2021, 10, 3, 15)\n    assert floor_op(date, \"MINUTE\") == datetime.datetime(2021, 10, 3, 15, 53)\n    assert floor_op(date, \"SECOND\") == datetime.datetime(2021, 10, 3, 15, 53, 42)\n    assert floor_op(date, \"MILLISECOND\") == datetime.datetime(2021, 10, 3, 15, 53, 42)\n"
  },
  {
    "path": "tests/unit/test_config.py",
    "content": "import os\nimport sys\nfrom unittest import mock\n\nimport dask.dataframe as dd\nimport pandas as pd\nimport pytest\nimport yaml\nfrom dask import config as dask_config\n\n# Required to instantiate default sql config\nimport dask_sql  # noqa: F401\nfrom dask_sql import Context\nfrom tests.utils import skipif_dask_expr_enabled\n\n\ndef test_custom_yaml(tmpdir):\n    custom_config = {}\n    custom_config[\"sql\"] = dask_config.get(\"sql\")\n    custom_config[\"sql\"][\"aggregate\"][\"split_out\"] = 16\n    custom_config[\"sql\"][\"foo\"] = {\"bar\": [1, 2, 3], \"baz\": None}\n\n    with open(os.path.join(tmpdir, \"custom-sql.yaml\"), mode=\"w\") as f:\n        yaml.dump(custom_config, f)\n\n    dask_config.refresh(\n        paths=[tmpdir]\n    )  # Refresh config to read from updated environment\n    assert custom_config[\"sql\"] == dask_config.get(\"sql\")\n    dask_config.refresh()\n\n\ndef test_env_variable():\n    with mock.patch.dict(\"os.environ\", {\"DASK_SQL__AGGREGATE__SPLIT_OUT\": \"200\"}):\n        dask_config.refresh()\n        assert dask_config.get(\"sql.aggregate.split-out\") == 200\n    dask_config.refresh()\n\n\ndef test_default_config():\n    config_fn = os.path.join(os.path.dirname(__file__), \"../../dask_sql\", \"sql.yaml\")\n    with open(config_fn) as f:\n        default_config = yaml.safe_load(f)\n    assert \"sql\" in default_config\n    assert default_config[\"sql\"] == dask_config.get(\"sql\")\n\n\ndef test_schema():\n    jsonschema = pytest.importorskip(\"jsonschema\")\n\n    config_fn = os.path.join(os.path.dirname(__file__), \"../../dask_sql\", \"sql.yaml\")\n    schema_fn = os.path.join(\n        os.path.dirname(__file__), \"../../dask_sql\", \"sql-schema.yaml\"\n    )\n\n    with open(config_fn) as f:\n        config = yaml.safe_load(f)\n\n    with open(schema_fn) as f:\n        schema = yaml.safe_load(f)\n\n    jsonschema.validate(config, schema)\n\n\ndef test_schema_is_complete():\n    config_fn = 
os.path.join(os.path.dirname(__file__), \"../../dask_sql\", \"sql.yaml\")\n    schema_fn = os.path.join(\n        os.path.dirname(__file__), \"../../dask_sql\", \"sql-schema.yaml\"\n    )\n\n    with open(config_fn) as f:\n        config = yaml.safe_load(f)\n\n    with open(schema_fn) as f:\n        schema = yaml.safe_load(f)\n\n    def test_matches(c, s):\n        for k, v in c.items():\n            if list(c) != list(s[\"properties\"]):\n                raise ValueError(\n                    \"\\nThe sql.yaml and sql-schema.yaml files are not in sync.\\n\"\n                    \"This usually happens when we add a new configuration value,\\n\"\n                    \"but don't add the schema of that value to the sql-schema.yaml file\\n\"\n                    \"Please modify these files to include the missing values: \\n\\n\"\n                    \"    sql.yaml:        {}\\n\"\n                    \"    sql-schema.yaml: {}\\n\\n\"\n                    \"Examples in these files should be a good start, \\n\"\n                    \"even if you are not familiar with the jsonschema spec\".format(\n                        sorted(c), sorted(s[\"properties\"])\n                    )\n                )\n            if isinstance(v, dict):\n                test_matches(c[k], s[\"properties\"][k])\n\n    test_matches(config, schema)\n\n\ndef test_dask_setconfig():\n    dask_config.set({\"sql.foo.bar\": 1})\n    with dask_config.set({\"sql.foo.baz\": \"2\"}):\n        assert dask_config.get(\"sql.foo\") == {\"bar\": 1, \"baz\": \"2\"}\n    assert dask_config.get(\"sql.foo\") == {\"bar\": 1}\n    dask_config.refresh()\n\n\n@pytest.mark.skipif(\n    sys.version_info < (3, 10),\n    reason=\"Writing and reading the Dask DataFrame causes a ProtocolError\",\n)\n@skipif_dask_expr_enabled(\"dynamic partition pruning not yet supported with dask-expr\")\ndef test_dynamic_partition_pruning(tmpdir):\n    c = Context()\n\n    df1 = pd.DataFrame(\n        {\n            \"x\": [1, 2, 3],\n 
           \"z\": [7, 8, 9],\n        },\n    )\n    dd.from_pandas(df1, npartitions=3).to_parquet(os.path.join(tmpdir, \"df1\"))\n    df1 = dd.read_parquet(os.path.join(tmpdir, \"df1\"))\n    c.create_table(\"df1\", df1)\n\n    df2 = pd.DataFrame(\n        {\n            \"x\": [1, 2, 3] * 1000,\n            \"y\": [4, 5, 6] * 1000,\n        },\n    )\n    dd.from_pandas(df2, npartitions=3).to_parquet(os.path.join(tmpdir, \"df2\"))\n    df2 = dd.read_parquet(os.path.join(tmpdir, \"df2\"))\n    c.create_table(\"df2\", df2)\n\n    query = \"SELECT * FROM df1, df2 WHERE df1.x = df2.x AND df1.z=7\"\n    inlist_expr = \"df2.x IN ([Int64(1)])\"\n\n    # Default value is False\n    dask_config.set({\"sql.optimizer.verbose\": True})\n\n    # When DPP is turned off, the explain output will not contain the INLIST expression\n    dask_config.set({\"sql.dynamic_partition_pruning\": False})\n    explain_string = c.explain(query)\n    assert inlist_expr not in explain_string\n\n    # When DPP is turned on but sql.optimizer.verbose is off, the explain output will not contain the\n    # INLIST expression\n    dask_config.set({\"sql.dynamic_partition_pruning\": True})\n    dask_config.set({\"sql.optimizer.verbose\": False})\n    explain_string = c.explain(query)\n    assert inlist_expr not in explain_string\n\n    # When both DPP and sql.optimizer.verbose are turned on, the explain output will contain the INLIST\n    # expression\n    dask_config.set({\"sql.dynamic_partition_pruning\": True})\n    dask_config.set({\"sql.optimizer.verbose\": True})\n    explain_string = c.explain(query)\n    assert inlist_expr in explain_string\n\n\n@skipif_dask_expr_enabled(\"dynamic partition pruning not yet supported with dask-expr\")\ndef test_dpp_single_file_parquet(tmpdir):\n    c = Context()\n\n    dask_config.set({\"sql.dynamic_partition_pruning\": True})\n    dask_config.set({\"sql.optimizer.verbose\": True})\n\n    df1 = pd.DataFrame(\n        {\n            \"x\": [1, 2, 3],\n            
\"z\": [7, 8, 9],\n        },\n    )\n    dd.from_pandas(df1, npartitions=1).to_parquet(\n        os.path.join(tmpdir, \"df1_single_file\")\n    )\n    df1 = dd.read_parquet(os.path.join(tmpdir, \"df1_single_file/part.0.parquet\"))\n    c.create_table(\"df1\", df1)\n\n    df2 = pd.DataFrame(\n        {\n            \"x\": [1, 2, 3] * 1000,\n            \"y\": [4, 5, 6] * 1000,\n        },\n    )\n    dd.from_pandas(df2, npartitions=3).to_parquet(os.path.join(tmpdir, \"df2\"))\n    df2 = dd.read_parquet(os.path.join(tmpdir, \"df2\"))\n    c.create_table(\"df2\", df2)\n\n    query = \"SELECT * FROM df1, df2 WHERE df1.x = df2.x AND df1.z=7\"\n    inlist_expr = \"df2.x IN ([Int64(1)])\"\n\n    explain_string = c.explain(query)\n    assert inlist_expr in explain_string\n"
  },
  {
    "path": "tests/unit/test_context.py",
    "content": "import os\nimport sys\n\nimport dask.dataframe as dd\nimport pandas as pd\nimport pytest\n\nfrom dask_sql import Context\nfrom tests.utils import assert_eq\n\ntry:\n    import cudf\n    import dask_cudf\nexcept ImportError:\n    cudf = None\n    dask_cudf = None\n\n# default integer type varies by platform\nDEFAULT_INT_TYPE = \"INTEGER\" if sys.platform == \"win32\" else \"BIGINT\"\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_add_remove_tables(gpu):\n    c = Context()\n\n    data_frame = dd.from_pandas(pd.DataFrame(), npartitions=1)\n\n    c.create_table(\"table\", data_frame, gpu=gpu)\n    assert \"table\" in c.schema[c.schema_name].tables\n\n    c.drop_table(\"table\")\n    assert \"table\" not in c.schema[c.schema_name].tables\n\n    with pytest.raises(KeyError):\n        c.drop_table(\"table\")\n\n    c.create_table(\"table\", [data_frame], gpu=gpu)\n    assert \"table\" in c.schema[c.schema_name].tables\n\n\n@pytest.mark.parametrize(\n    \"gpu\",\n    [\n        False,\n        pytest.param(\n            True,\n            marks=pytest.mark.gpu,\n        ),\n    ],\n)\ndef test_sql(gpu):\n    c = Context()\n\n    data_frame = dd.from_pandas(pd.DataFrame({\"a\": [1, 2, 3]}), npartitions=1)\n    c.create_table(\"df\", data_frame, gpu=gpu)\n\n    result = c.sql(\"SELECT * FROM df\")\n    assert isinstance(result, dd.DataFrame)\n    assert_eq(result, data_frame)\n\n    result = c.sql(\"SELECT * FROM df\", return_futures=False)\n    assert not isinstance(result, dd.DataFrame)\n    assert_eq(result, data_frame)\n\n    result = c.sql(\n        \"SELECT * FROM other_df\", dataframes={\"other_df\": data_frame}, gpu=gpu\n    )\n    assert isinstance(result, dd.DataFrame)\n    assert_eq(result, data_frame)\n\n\n@pytest.mark.parametrize(\n    \"gpu\",\n    [\n        False,\n        pytest.param(\n            True,\n            marks=pytest.mark.gpu,\n        ),\n    ],\n)\ndef 
test_input_types(temporary_data_file, gpu):\n    c = Context()\n    df = pd.DataFrame({\"a\": [1, 2, 3]})\n\n    def assert_correct_output(gpu):\n        result = c.sql(\"SELECT * FROM df\")\n        assert isinstance(result, dd.DataFrame if not gpu else dask_cudf.DataFrame)\n        assert_eq(result, df)\n\n    c.create_table(\"df\", df, gpu=gpu)\n    assert_correct_output(gpu=gpu)\n\n    c.create_table(\"df\", dd.from_pandas(df, npartitions=1), gpu=gpu)\n    assert_correct_output(gpu=gpu)\n\n    df.to_csv(temporary_data_file, index=False)\n    c.create_table(\"df\", temporary_data_file, gpu=gpu)\n    assert_correct_output(gpu=gpu)\n\n    df.to_csv(temporary_data_file, index=False)\n    c.create_table(\"df\", temporary_data_file, format=\"csv\", gpu=gpu)\n    assert_correct_output(gpu=gpu)\n\n    df.to_parquet(temporary_data_file, index=False)\n    c.create_table(\"df\", temporary_data_file, format=\"parquet\", gpu=gpu)\n    assert_correct_output(gpu=gpu)\n\n    with pytest.raises(AttributeError):\n        c.create_table(\"df\", temporary_data_file, format=\"unknown\", gpu=gpu)\n\n    strangeThing = object()\n\n    with pytest.raises(ValueError):\n        c.create_table(\"df\", strangeThing, gpu=gpu)\n\n\n@pytest.mark.parametrize(\n    \"gpu\",\n    [\n        False,\n        pytest.param(True, marks=pytest.mark.gpu),\n    ],\n)\ndef test_tables_from_stack(gpu):\n    c = Context()\n\n    assert not c._get_tables_from_stack()\n\n    df = pd.DataFrame() if not gpu else cudf.DataFrame()\n\n    assert \"df\" in c._get_tables_from_stack()\n\n    def f(gpu):\n        df2 = pd.DataFrame() if not gpu else cudf.DataFrame()\n\n        assert \"df\" in c._get_tables_from_stack()\n        assert \"df2\" in c._get_tables_from_stack()\n\n    f(gpu=gpu)\n\n    def g(gpu=gpu):\n        df = pd.DataFrame({\"a\": [1]}) if not gpu else cudf.DataFrame({\"a\": [1]})\n\n        assert \"df\" in c._get_tables_from_stack()\n        assert c._get_tables_from_stack()[\"df\"].columns == 
[\"a\"]\n\n    g(gpu=gpu)\n\n\ndef test_function_adding():\n    c = Context()\n\n    assert not c.schema[c.schema_name].function_lists\n    assert not c.schema[c.schema_name].functions\n\n    f = lambda x: x\n    c.register_function(f, \"f\", [(\"x\", int)], float)\n\n    assert \"f\" in c.schema[c.schema_name].functions\n    assert c.schema[c.schema_name].functions[\"f\"].func == f\n    assert len(c.schema[c.schema_name].function_lists) == 2\n    assert c.schema[c.schema_name].function_lists[0].name == \"F\"\n    assert c.schema[c.schema_name].function_lists[0].parameters[0][0] == \"x\"\n    assert (\n        str(c.schema[c.schema_name].function_lists[0].parameters[0][1])\n        == DEFAULT_INT_TYPE\n    )\n    assert str(c.schema[c.schema_name].function_lists[0].return_type) == \"DOUBLE\"\n    assert not c.schema[c.schema_name].function_lists[0].aggregation\n    assert c.schema[c.schema_name].function_lists[1].name == \"f\"\n    assert c.schema[c.schema_name].function_lists[1].parameters[0][0] == \"x\"\n    assert (\n        str(c.schema[c.schema_name].function_lists[1].parameters[0][1])\n        == DEFAULT_INT_TYPE\n    )\n    assert str(c.schema[c.schema_name].function_lists[1].return_type) == \"DOUBLE\"\n    assert not c.schema[c.schema_name].function_lists[1].aggregation\n\n    # Without replacement\n    c.register_function(f, \"f\", [(\"x\", float)], int, replace=False)\n\n    assert \"f\" in c.schema[c.schema_name].functions\n    assert c.schema[c.schema_name].functions[\"f\"].func == f\n    assert len(c.schema[c.schema_name].function_lists) == 4\n    assert c.schema[c.schema_name].function_lists[2].name == \"F\"\n    assert c.schema[c.schema_name].function_lists[2].parameters[0][0] == \"x\"\n    assert str(c.schema[c.schema_name].function_lists[2].parameters[0][1]) == \"DOUBLE\"\n    assert (\n        str(c.schema[c.schema_name].function_lists[2].return_type) == DEFAULT_INT_TYPE\n    )\n    assert not 
c.schema[c.schema_name].function_lists[2].aggregation\n    assert c.schema[c.schema_name].function_lists[3].name == \"f\"\n    assert c.schema[c.schema_name].function_lists[3].parameters[0][0] == \"x\"\n    assert str(c.schema[c.schema_name].function_lists[3].parameters[0][1]) == \"DOUBLE\"\n    assert (\n        str(c.schema[c.schema_name].function_lists[3].return_type) == DEFAULT_INT_TYPE\n    )\n    assert not c.schema[c.schema_name].function_lists[3].aggregation\n\n    # With replacement\n    f = lambda x: x + 1\n    c.register_function(f, \"f\", [(\"x\", str)], str, replace=True)\n\n    assert \"f\" in c.schema[c.schema_name].functions\n    assert c.schema[c.schema_name].functions[\"f\"].func == f\n    assert len(c.schema[c.schema_name].function_lists) == 2\n    assert c.schema[c.schema_name].function_lists[0].name == \"F\"\n    assert c.schema[c.schema_name].function_lists[0].parameters[0][0] == \"x\"\n    assert str(c.schema[c.schema_name].function_lists[0].parameters[0][1]) == \"VARCHAR\"\n    assert str(c.schema[c.schema_name].function_lists[0].return_type) == \"VARCHAR\"\n    assert not c.schema[c.schema_name].function_lists[0].aggregation\n    assert c.schema[c.schema_name].function_lists[1].name == \"f\"\n    assert c.schema[c.schema_name].function_lists[1].parameters[0][0] == \"x\"\n    assert str(c.schema[c.schema_name].function_lists[1].parameters[0][1]) == \"VARCHAR\"\n    assert str(c.schema[c.schema_name].function_lists[1].return_type) == \"VARCHAR\"\n    assert not c.schema[c.schema_name].function_lists[1].aggregation\n\n\ndef test_aggregation_adding():\n    c = Context()\n\n    assert not c.schema[c.schema_name].function_lists\n    assert not c.schema[c.schema_name].functions\n\n    f = lambda x: x\n    c.register_aggregation(f, \"f\", [(\"x\", int)], float)\n\n    assert \"f\" in c.schema[c.schema_name].functions\n    assert c.schema[c.schema_name].functions[\"f\"] == f\n    assert len(c.schema[c.schema_name].function_lists) == 2\n    assert 
c.schema[c.schema_name].function_lists[0].name == \"F\"\n    assert c.schema[c.schema_name].function_lists[0].parameters[0][0] == \"x\"\n    assert (\n        str(c.schema[c.schema_name].function_lists[0].parameters[0][1])\n        == DEFAULT_INT_TYPE\n    )\n    assert str(c.schema[c.schema_name].function_lists[0].return_type) == \"DOUBLE\"\n    assert c.schema[c.schema_name].function_lists[0].aggregation\n    assert c.schema[c.schema_name].function_lists[1].name == \"f\"\n    assert c.schema[c.schema_name].function_lists[1].parameters[0][0] == \"x\"\n    assert (\n        str(c.schema[c.schema_name].function_lists[1].parameters[0][1])\n        == DEFAULT_INT_TYPE\n    )\n    assert str(c.schema[c.schema_name].function_lists[1].return_type) == \"DOUBLE\"\n    assert c.schema[c.schema_name].function_lists[1].aggregation\n\n    # Without replacement\n    c.register_aggregation(f, \"f\", [(\"x\", float)], int, replace=False)\n\n    assert \"f\" in c.schema[c.schema_name].functions\n    assert c.schema[c.schema_name].functions[\"f\"] == f\n    assert len(c.schema[c.schema_name].function_lists) == 4\n    assert c.schema[c.schema_name].function_lists[2].name == \"F\"\n    assert c.schema[c.schema_name].function_lists[2].parameters[0][0] == \"x\"\n    assert str(c.schema[c.schema_name].function_lists[2].parameters[0][1]) == \"DOUBLE\"\n    assert (\n        str(c.schema[c.schema_name].function_lists[2].return_type) == DEFAULT_INT_TYPE\n    )\n    assert c.schema[c.schema_name].function_lists[2].aggregation\n    assert c.schema[c.schema_name].function_lists[3].name == \"f\"\n    assert c.schema[c.schema_name].function_lists[3].parameters[0][0] == \"x\"\n    assert str(c.schema[c.schema_name].function_lists[3].parameters[0][1]) == \"DOUBLE\"\n    assert (\n        str(c.schema[c.schema_name].function_lists[3].return_type) == DEFAULT_INT_TYPE\n    )\n    assert c.schema[c.schema_name].function_lists[3].aggregation\n\n    # With replacement\n    f = lambda x: x + 1\n    
c.register_aggregation(f, \"f\", [(\"x\", str)], str, replace=True)\n\n    assert \"f\" in c.schema[c.schema_name].functions\n    assert c.schema[c.schema_name].functions[\"f\"] == f\n    assert len(c.schema[c.schema_name].function_lists) == 2\n    assert c.schema[c.schema_name].function_lists[0].name == \"F\"\n    assert c.schema[c.schema_name].function_lists[0].parameters[0][0] == \"x\"\n    assert str(c.schema[c.schema_name].function_lists[0].parameters[0][1]) == \"VARCHAR\"\n    assert str(c.schema[c.schema_name].function_lists[0].return_type) == \"VARCHAR\"\n    assert c.schema[c.schema_name].function_lists[0].aggregation\n    assert c.schema[c.schema_name].function_lists[1].name == \"f\"\n    assert c.schema[c.schema_name].function_lists[1].parameters[0][0] == \"x\"\n    assert str(c.schema[c.schema_name].function_lists[1].parameters[0][1]) == \"VARCHAR\"\n    assert str(c.schema[c.schema_name].function_lists[1].return_type) == \"VARCHAR\"\n    assert c.schema[c.schema_name].function_lists[1].aggregation\n\n\ndef test_alter_schema(c):\n    c.create_schema(\"test_schema\")\n    c.sql(\"ALTER SCHEMA test_schema RENAME TO prod_schema\")\n    assert \"prod_schema\" in c.schema\n    assert \"test_schema\" not in c.schema\n\n    with pytest.raises(KeyError):\n        c.sql(\"ALTER SCHEMA MARVEL RENAME TO DC\")\n\n    del c.schema[\"prod_schema\"]\n\n\ndef test_alter_table(c, df_simple):\n    c.create_table(\"maths\", df_simple)\n    c.sql(\"ALTER TABLE maths RENAME TO physics\")\n    assert \"physics\" in c.schema[c.schema_name].tables\n    assert \"maths\" not in c.schema[c.schema_name].tables\n\n    with pytest.raises(KeyError):\n        c.sql(\"ALTER TABLE four_legs RENAME TO two_legs\")\n\n    c.sql(\"ALTER TABLE IF EXISTS alien RENAME TO humans\")\n\n    del c.schema[c.schema_name].tables[\"physics\"]\n\n\ndef test_filepath(tmpdir, parquet_ddf):\n    c = Context()\n    parquet_path = os.path.join(tmpdir, \"parquet\")\n\n    # Create table with string (Parquet 
filepath)\n    c.create_table(\"parquet_ddf\", parquet_path, format=\"parquet\")\n\n    assert c.schema[\"root\"].tables[\"parquet_ddf\"].filepath == parquet_path\n    assert c.schema[\"root\"].filepaths[\"parquet_ddf\"] == parquet_path\n\n    df = pd.DataFrame({\"a\": [2, 1, 2, 3], \"b\": [3, 3, 1, 3]})\n    c.create_table(\"df\", df)\n\n    assert c.schema[\"root\"].tables[\"df\"].filepath is None\n    with pytest.raises(KeyError):\n        c.schema[\"root\"].filepaths[\"df\"]\n\n\ndef test_ddf_filepath(tmpdir, parquet_ddf):\n    c = Context()\n    parquet_path = os.path.join(tmpdir, \"parquet\")\n\n    # Create table with Dask DataFrame (created from read_parquet)\n    c.create_table(\"parquet_ddf\", parquet_ddf)\n\n    assert c.schema[\"root\"].tables[\"parquet_ddf\"].filepath == parquet_path\n    assert c.schema[\"root\"].filepaths[\"parquet_ddf\"] == parquet_path\n"
  },
  {
    "path": "tests/unit/test_datacontainer.py",
    "content": "from dask_sql.datacontainer import ColumnContainer\n\n\ndef test_cc_init():\n    c = ColumnContainer([\"a\", \"b\", \"c\"])\n\n    assert c.columns == [\"a\", \"b\", \"c\"]\n    assert c.mapping() == [(\"a\", \"a\"), (\"b\", \"b\"), (\"c\", \"c\")]\n\n    c = ColumnContainer([\"a\", \"b\", \"c\"], {\"a\": \"1\", \"b\": \"2\", \"c\": \"3\"})\n\n    assert c.columns == [\"a\", \"b\", \"c\"]\n    assert c.mapping() == [(\"a\", \"1\"), (\"b\", \"2\"), (\"c\", \"3\")]\n\n\ndef test_cc_limit_to():\n    c = ColumnContainer([\"a\", \"b\", \"c\"])\n\n    c2 = c.limit_to([\"c\", \"a\"])\n\n    assert c2.columns == [\"c\", \"a\"]\n    assert c2.mapping() == [(\"a\", \"a\"), (\"b\", \"b\"), (\"c\", \"c\")]\n    assert c.columns == [\"a\", \"b\", \"c\"]\n    assert c.mapping() == [(\"a\", \"a\"), (\"b\", \"b\"), (\"c\", \"c\")]\n\n\ndef test_cc_rename():\n    c = ColumnContainer([\"a\", \"b\", \"c\"])\n\n    c2 = c.rename({\"a\": \"A\", \"b\": \"a\"})\n\n    assert c2.columns == [\"A\", \"a\", \"c\"]\n    assert c2.mapping() == [(\"a\", \"b\"), (\"b\", \"b\"), (\"c\", \"c\"), (\"A\", \"a\")]\n    assert c.columns == [\"a\", \"b\", \"c\"]\n    assert c.mapping() == [(\"a\", \"a\"), (\"b\", \"b\"), (\"c\", \"c\")]\n\n\ndef test_cc_add():\n    c = ColumnContainer([\"a\", \"b\", \"c\"])\n\n    c2 = c.add(\"d\")\n\n    assert c2.columns == [\"a\", \"b\", \"c\", \"d\"]\n    assert c2.mapping() == [(\"a\", \"a\"), (\"b\", \"b\"), (\"c\", \"c\"), (\"d\", \"d\")]\n    assert c.columns == [\"a\", \"b\", \"c\"]\n    assert c.mapping() == [(\"a\", \"a\"), (\"b\", \"b\"), (\"c\", \"c\")]\n\n    c2 = c.add(\"d\", \"D\")\n\n    assert c2.columns == [\"a\", \"b\", \"c\", \"d\"]\n    assert c2.mapping() == [(\"a\", \"a\"), (\"b\", \"b\"), (\"c\", \"c\"), (\"d\", \"D\")]\n    assert c.columns == [\"a\", \"b\", \"c\"]\n    assert c.mapping() == [(\"a\", \"a\"), (\"b\", \"b\"), (\"c\", \"c\")]\n\n    c2 = c.add(\"d\", \"a\")\n\n    assert c2.columns == [\"a\", \"b\", \"c\", 
\"d\"]\n    assert c2.mapping() == [(\"a\", \"a\"), (\"b\", \"b\"), (\"c\", \"c\"), (\"d\", \"a\")]\n    assert c.columns == [\"a\", \"b\", \"c\"]\n    assert c.mapping() == [(\"a\", \"a\"), (\"b\", \"b\"), (\"c\", \"c\")]\n\n    c2 = c.add(\"a\", \"b\")\n\n    assert c2.columns == [\"a\", \"b\", \"c\"]\n    assert c2.mapping() == [(\"a\", \"b\"), (\"b\", \"b\"), (\"c\", \"c\")]\n    assert c.columns == [\"a\", \"b\", \"c\"]\n    assert c.mapping() == [(\"a\", \"a\"), (\"b\", \"b\"), (\"c\", \"c\")]\n"
  },
  {
    "path": "tests/unit/test_mapping.py",
    "content": "from datetime import timedelta\n\nimport numpy as np\nimport pandas as pd\nimport pytest\n\nfrom dask_sql._datafusion_lib import SqlTypeName\nfrom dask_sql.mappings import python_to_sql_type, similar_type, sql_to_python_value\n\n\ndef test_python_to_sql():\n    assert str(python_to_sql_type(np.dtype(\"int32\"))) == \"INTEGER\"\n    assert str(python_to_sql_type(np.dtype(\">M8[ns]\"))) == \"TIMESTAMP\"\n    assert (\n        str(python_to_sql_type(pd.DatetimeTZDtype(unit=\"ns\", tz=\"UTC\")))\n        == \"TIMESTAMP_WITH_LOCAL_TIME_ZONE\"\n    )\n\n\n@pytest.mark.gpu\ndef test_python_decimal_to_sql():\n    import cudf\n\n    assert str(python_to_sql_type(cudf.Decimal64Dtype(12, 3))) == \"DECIMAL\"\n    assert str(python_to_sql_type(cudf.Decimal128Dtype(32, 12))) == \"DECIMAL\"\n    assert str(python_to_sql_type(cudf.Decimal32Dtype(5, -2))) == \"DECIMAL\"\n\n\ndef test_sql_to_python():\n    assert sql_to_python_value(SqlTypeName.VARCHAR, \"test 123\") == \"test 123\"\n    assert type(sql_to_python_value(SqlTypeName.BIGINT, 653)) == np.int64\n    assert sql_to_python_value(SqlTypeName.BIGINT, 653) == 653\n    assert sql_to_python_value(SqlTypeName.INTERVAL, 4) == timedelta(microseconds=4000)\n\n\ndef test_python_to_sql_to_python():\n    assert (\n        type(\n            sql_to_python_value(python_to_sql_type(np.dtype(\"int64\")).getSqlType(), 54)\n        )\n        == np.int64\n    )\n\n\ndef test_similar_type():\n    assert similar_type(np.int64, np.int32)\n    assert similar_type(pd.Int64Dtype(), np.int32)\n    assert not similar_type(np.uint32, np.int32)\n    assert similar_type(np.float32, np.float64)\n    assert similar_type(object, str)\n"
  },
  {
    "path": "tests/unit/test_ml_utils.py",
    "content": "# Copyright 2017, Dask developers\n# Dask-ML project - https://github.com/dask/dask-ml\nfrom collections.abc import Sequence\n\nimport dask\nimport dask.array as da\nimport dask.dataframe as dd\nimport numpy as np\nimport pandas as pd\nimport pytest\nfrom dask.array.utils import assert_eq as assert_eq_ar\nfrom dask.dataframe.utils import assert_eq as assert_eq_df\nfrom sklearn.base import clone\nfrom sklearn.decomposition import PCA\nfrom sklearn.ensemble import GradientBoostingClassifier\nfrom sklearn.linear_model import LogisticRegression, SGDClassifier\n\nfrom dask_sql.physical.rel.custom.wrappers import Incremental, ParallelPostFit\n\n\n@pytest.mark.parametrize(\"gpu\", [False, pytest.param(True, marks=pytest.mark.gpu)])\ndef test_ml_class_mappings(gpu):\n    from dask_sql.physical.utils.ml_classes import get_cpu_classes, get_gpu_classes\n    from dask_sql.utils import import_class\n\n    try:\n        import lightgbm\n        import xgboost\n    except KeyError:\n        lightgbm = None\n        xgboost = None\n\n    classes_dict = get_gpu_classes() if gpu else get_cpu_classes()\n\n    for key in classes_dict:\n        if not (\"XGB\" in key and xgboost is None) and not (\n            \"LGBM\" in key and lightgbm is None\n        ):\n            import_class(classes_dict[key])\n\n\ndef _check_axis_partitioning(chunks, n_features):\n    c = chunks[1][0]\n    if c != n_features:\n        msg = (\n            \"Can only generate arrays partitioned along the \"\n            \"first axis. 
Specifying a larger chunksize for \"\n            \"the second axis.\\n\\n\\tchunk size: {}\\n\"\n            \"\\tn_features: {}\".format(c, n_features)\n        )\n        raise ValueError(msg)\n\n\ndef check_random_state(random_state):\n    if random_state is None:\n        return da.random.RandomState()\n    # elif isinstance(random_state, Integral):\n    #     return da.random.RandomState(random_state)\n    elif isinstance(random_state, np.random.RandomState):\n        return da.random.RandomState(random_state.randint())\n    elif isinstance(random_state, da.random.RandomState):\n        return random_state\n    else:\n        raise TypeError(f\"Unexpected type '{type(random_state)}'\")\n\n\ndef make_classification(\n    n_samples=100,\n    n_features=20,\n    n_informative=2,\n    n_classes=2,\n    scale=1.0,\n    random_state=None,\n    chunks=None,\n):\n    chunks = da.core.normalize_chunks(chunks, (n_samples, n_features))\n    _check_axis_partitioning(chunks, n_features)\n\n    if n_classes != 2:\n        raise NotImplementedError(\"n_classes != 2 is not yet supported.\")\n\n    rng = check_random_state(random_state)\n\n    X = rng.normal(0, 1, size=(n_samples, n_features), chunks=chunks)\n    informative_idx = rng.choice(n_features, n_informative, chunks=n_informative)\n    beta = (rng.random(n_features, chunks=n_features) - 1) * scale\n\n    informative_idx, beta = dask.compute(\n        informative_idx, beta, scheduler=\"single-threaded\"\n    )\n\n    z0 = X[:, informative_idx].dot(beta[informative_idx])\n    y = rng.random(z0.shape, chunks=chunks[0]) < 1 / (1 + da.exp(-z0))\n    y = y.astype(int)\n\n    return X, y\n\n\ndef _assert_eq(l, r, name=None, **kwargs):\n    array_types = (np.ndarray, da.Array)\n    frame_types = (pd.core.generic.NDFrame, dd.DataFrame)\n    if isinstance(l, array_types):\n        assert_eq_ar(l, r, **kwargs)\n    elif isinstance(l, frame_types):\n        assert_eq_df(l, r, **kwargs)\n    elif isinstance(l, Sequence) and 
any(\n        isinstance(x, array_types + frame_types) for x in l\n    ):\n        for a, b in zip(l, r):\n            _assert_eq(a, b, **kwargs)\n    elif np.isscalar(r) and np.isnan(r):\n        assert np.isnan(l), (name, l, r)\n    else:\n        assert l == r, (name, l, r)\n\n\ndef assert_estimator_equal(left, right, exclude=None, **kwargs):\n    \"\"\"Check that two Estimators are equal\n    Parameters\n    ----------\n    left, right : Estimators\n    exclude : str or sequence of str\n        attributes to skip in the check\n    kwargs : dict\n        Passed through to the dask `assert_eq` method.\n    \"\"\"\n    left_attrs = [x for x in dir(left) if x.endswith(\"_\") and not x.startswith(\"_\")]\n    right_attrs = [x for x in dir(right) if x.endswith(\"_\") and not x.startswith(\"_\")]\n    if exclude is None:\n        exclude = set()\n    elif isinstance(exclude, str):\n        exclude = {exclude}\n    else:\n        exclude = set(exclude)\n\n    left_attrs2 = set(left_attrs) - exclude\n    right_attrs2 = set(right_attrs) - exclude\n\n    assert left_attrs2 == right_attrs2, left_attrs2 ^ right_attrs2\n\n    for attr in left_attrs2:\n        l = getattr(left, attr)\n        r = getattr(right, attr)\n        _assert_eq(l, r, name=attr, **kwargs)\n\n\ndef test_parallelpostfit_basic():\n    clf = ParallelPostFit(GradientBoostingClassifier())\n\n    X, y = make_classification(n_samples=1000, chunks=100)\n    X_, y_ = dask.compute(X, y)\n    clf.fit(X_, y_)\n\n    assert isinstance(clf.predict(X), da.Array)\n    assert isinstance(clf.predict_proba(X), da.Array)\n\n    result = clf.score(X, y)\n    expected = clf.estimator.score(X_, y_)\n    assert result == expected\n\n\n@pytest.mark.parametrize(\"kind\", [\"numpy\", \"dask.dataframe\", \"dask.array\"])\ndef test_predict(kind):\n    X, y = make_classification(chunks=100)\n\n    if kind == \"numpy\":\n        X, y = dask.compute(X, y)\n    elif kind == \"dask.dataframe\":\n        X = dd.from_dask_array(X)\n      
  y = dd.from_dask_array(y)\n\n    base = LogisticRegression(random_state=0, n_jobs=1, solver=\"lbfgs\")\n    wrap = ParallelPostFit(\n        LogisticRegression(random_state=0, n_jobs=1, solver=\"lbfgs\"),\n    )\n\n    base.fit(*dask.compute(X, y))\n    wrap.fit(*dask.compute(X, y))\n\n    assert_estimator_equal(wrap.estimator, base)\n\n    result = wrap.predict(X)\n    expected = base.predict(X)\n    assert_eq_ar(result, expected)\n\n    result = wrap.predict_proba(X)\n    expected = base.predict_proba(X)\n    assert_eq_ar(result, expected)\n\n    result = wrap.predict_log_proba(X)\n    expected = base.predict_log_proba(X)\n    assert_eq_ar(result, expected)\n\n\n@pytest.mark.parametrize(\"kind\", [\"numpy\", \"dask.dataframe\", \"dask.array\"])\ndef test_transform(kind):\n    X, y = make_classification(chunks=100)\n\n    if kind == \"numpy\":\n        X, y = dask.compute(X, y)\n    elif kind == \"dask.dataframe\":\n        X = dd.from_dask_array(X)\n        y = dd.from_dask_array(y)\n\n    base = PCA(random_state=0)\n    wrap = ParallelPostFit(PCA(random_state=0))\n\n    base.fit(*dask.compute(X, y))\n    wrap.fit(*dask.compute(X, y))\n\n    assert_estimator_equal(wrap.estimator, base)\n\n    result = base.transform(*dask.compute(X))\n    expected = wrap.transform(X)\n    assert_eq_ar(result, expected)\n\n\n@pytest.mark.parametrize(\"dataframes\", [False, True])\ndef test_incremental_basic(dataframes):\n    # Create observations that we know linear models can recover\n    n, d = 100, 3\n    rng = da.random.RandomState(42)\n    X = rng.normal(size=(n, d), chunks=30)\n    coef_star = rng.uniform(size=d, chunks=d)\n    y = da.sign(X.dot(coef_star))\n    y = (y + 1) / 2\n    if dataframes:\n        X = dd.from_array(X)\n        y = dd.from_array(y)\n\n    est1 = SGDClassifier(random_state=0, tol=1e-3, average=True)\n    est2 = clone(est1)\n\n    clf = Incremental(est1, random_state=0)\n    result = clf.fit(X, y, classes=[0, 1])\n    assert result is clf\n\n    # 
est2 is a sklearn optimizer; this is just a benchmark\n    if dataframes:\n        X = X.to_dask_array(lengths=True)\n        y = y.to_dask_array(lengths=True)\n\n    for slice_ in da.core.slices_from_chunks(X.chunks):\n        est2.partial_fit(X[slice_].compute(), y[slice_[0]].compute(), classes=[0, 1])\n\n    assert isinstance(result.estimator_.coef_, np.ndarray)\n    rel_error = np.linalg.norm(clf.coef_ - est2.coef_)\n    rel_error /= np.linalg.norm(clf.coef_)\n    assert rel_error < 0.9\n\n    assert set(dir(clf.estimator_)) == set(dir(est2))\n\n    #  Predict\n    result = clf.predict(X)\n    expected = est2.predict(X)\n    assert isinstance(result, da.Array)\n    if dataframes:\n        # Compute is needed because chunk sizes of this array are unknown\n        result = result.compute()\n    rel_error = np.linalg.norm(result - expected)\n    rel_error /= np.linalg.norm(expected)\n    assert rel_error < 0.3\n\n    # score\n    result = clf.score(X, y)\n    expected = est2.score(*dask.compute(X, y))\n    assert abs(result - expected) < 0.1\n\n    clf = Incremental(SGDClassifier(random_state=0, tol=1e-3, average=True))\n    clf.partial_fit(X, y, classes=[0, 1])\n    assert set(dir(clf.estimator_)) == set(dir(est2))\n"
  },
  {
    "path": "tests/unit/test_queries.py",
    "content": "import os\n\nimport pytest\n\nXFAIL_QUERIES = (\n    5,\n    8,\n    10,\n    14,\n    16,\n    18,\n    22,\n    23,\n    24,\n    27,\n    28,\n    35,\n    36,\n    39,\n    41,\n    44,\n    47,\n    49,\n    51,\n    57,\n    62,\n    64,  # FIXME: failing after cudf#14167 and #14079\n    67,\n    69,\n    70,\n    72,\n    77,\n    80,\n    86,\n    88,\n    89,\n    92,\n    94,\n    99,\n)\n\nQUERIES = [\n    pytest.param(f\"q{i}.sql\", marks=pytest.mark.xfail if i in XFAIL_QUERIES else ())\n    for i in range(1, 100)\n]\n\n\n@pytest.fixture(scope=\"module\")\ndef c(data_dir):\n    # Lazy import, otherwise the pytest framework has problems\n    from dask_sql.context import Context\n\n    c = Context()\n    if not data_dir:\n        data_dir = f\"{os.path.dirname(__file__)}/data/\"\n    for table_name in os.listdir(data_dir):\n        c.create_table(\n            table_name,\n            data_dir + \"/\" + table_name,\n            format=\"parquet\",\n            gpu=False,\n        )\n\n    yield c\n\n\n@pytest.fixture(scope=\"module\")\ndef gpu_c(data_dir):\n    pytest.importorskip(\"dask_cudf\")\n\n    # Lazy import, otherwise the pytest framework has problems\n    from dask_sql.context import Context\n\n    c = Context()\n    if not data_dir:\n        data_dir = f\"{os.path.dirname(__file__)}/data/\"\n    for table_name in os.listdir(data_dir):\n        c.create_table(\n            table_name,\n            data_dir + \"/\" + table_name,\n            format=\"parquet\",\n            gpu=True,\n        )\n\n    yield c\n\n\n@pytest.mark.queries\n@pytest.mark.parametrize(\"query\", QUERIES)\ndef test_query(c, client, query, queries_dir):\n    if not queries_dir:\n        queries_dir = f\"{os.path.dirname(__file__)}/queries/\"\n    with open(queries_dir + \"/\" + query) as f:\n        sql = f.read()\n\n    res = c.sql(sql)\n    res.compute(scheduler=client)\n\n\n@pytest.mark.gpu\n@pytest.mark.queries\n@pytest.mark.parametrize(\"query\", 
QUERIES)\ndef test_gpu_query(gpu_c, gpu_client, query, queries_dir):\n    if not queries_dir:\n        queries_dir = f\"{os.path.dirname(__file__)}/queries/\"\n    with open(queries_dir + \"/\" + query) as f:\n        sql = f.read()\n\n    res = gpu_c.sql(sql)\n    res.compute(scheduler=gpu_client)\n"
  },
  {
    "path": "tests/unit/test_statistics.py",
    "content": "import dask.dataframe as dd\nimport pandas as pd\nimport pytest\n\nfrom dask_sql import Context\nfrom dask_sql.datacontainer import Statistics\nfrom dask_sql.physical.utils.statistics import parquet_statistics\nfrom tests.utils import skipif_dask_expr_enabled\n\n# TODO: add support for parquet statistics with dask-expr\npytestmark = skipif_dask_expr_enabled(\n    reason=\"Parquet statistics not yet supported with dask-expr\"\n)\n\n\n@pytest.mark.parametrize(\"parallel\", [None, False, 2])\ndef test_parquet_statistics(parquet_ddf, parallel):\n\n    # Check simple num-rows statistics\n    stats = parquet_statistics(parquet_ddf, parallel=parallel)\n    stats_df = pd.DataFrame(stats)\n    num_rows = stats_df[\"num-rows\"].sum()\n    assert len(stats_df) == parquet_ddf.npartitions\n    assert num_rows == len(parquet_ddf)\n\n    # Check simple column statistics\n    stats = parquet_statistics(parquet_ddf, columns=[\"b\"], parallel=parallel)\n    b_stats = [\n        {\n            \"min\": stat[\"columns\"][0][\"min\"],\n            \"max\": stat[\"columns\"][0][\"max\"],\n        }\n        for stat in stats\n    ]\n    b_stats_df = pd.DataFrame(b_stats)\n    assert b_stats_df[\"min\"].min() == parquet_ddf[\"b\"].min().compute()\n    assert b_stats_df[\"max\"].max() == parquet_ddf[\"b\"].max().compute()\n\n\ndef test_parquet_statistics_bad_args(parquet_ddf):\n    # Check \"bad\" input arguments to parquet_statistics\n\n    # ddf argument must be a Dask-DataFrame object\n    pdf = pd.DataFrame({\"a\": range(10)})\n    with pytest.raises(ValueError, match=\"Expected Dask DataFrame\"):\n        parquet_statistics(pdf)\n\n    # Return should be None if parquet statistics\n    # cannot be extracted from the provided collection\n    ddf = dd.from_pandas(pdf, npartitions=2)\n    assert parquet_statistics(ddf) is None\n\n    # Clear error should be raised when columns is not\n    # a list containing a subset of columns from ddf\n    with 
pytest.raises(ValueError, match=\"Expected columns to be a list\"):\n        parquet_statistics(parquet_ddf, columns=\"bad\")\n\n    with pytest.raises(ValueError, match=\"must be a subset\"):\n        parquet_statistics(parquet_ddf, columns=[\"bad\"])\n\n\ndef test_dc_statistics(parquet_ddf):\n    c = Context()\n    c.create_table(\"df\", parquet_ddf)\n\n    assert c.schema[\"root\"].tables[\"df\"].statistics == Statistics(row_count=15)\n    assert c.schema[\"root\"].statistics[\"df\"] == Statistics(row_count=15)\n"
  },
  {
    "path": "tests/unit/test_utils.py",
    "content": "import pandas as pd\nimport pytest\nfrom dask import dataframe as dd\nfrom dask.utils_test import hlg_layer\n\nfrom dask_sql.physical.utils.filter import attempt_predicate_pushdown\nfrom dask_sql.utils import Pluggable, is_frame\nfrom tests.utils import skipif_dask_expr_enabled\n\n\ndef test_is_frame_for_frame():\n    df = dd.from_pandas(pd.DataFrame({\"a\": [1]}), npartitions=1)\n    assert is_frame(df)\n\n\ndef test_is_frame_for_none():\n    assert not is_frame(None)\n\n\ndef test_is_frame_for_number():\n    assert not is_frame(3)\n    assert not is_frame(3.5)\n\n\nclass PluginTest1(Pluggable):\n    pass\n\n\nclass PluginTest2(Pluggable):\n    pass\n\n\ndef test_add_plugin():\n    PluginTest1.add_plugin(\"some_key\", \"value\")\n\n    assert PluginTest1.get_plugin(\"some_key\") == \"value\"\n    assert PluginTest1().get_plugin(\"some_key\") == \"value\"\n\n    with pytest.raises(KeyError):\n        PluginTest2.get_plugin(\"some_key\")\n\n\ndef test_overwrite():\n    PluginTest1.add_plugin(\"some_key\", \"value\")\n\n    assert PluginTest1.get_plugin(\"some_key\") == \"value\"\n    assert PluginTest1().get_plugin(\"some_key\") == \"value\"\n\n    PluginTest1.add_plugin(\"some_key\", \"value_2\")\n\n    assert PluginTest1.get_plugin(\"some_key\") == \"value_2\"\n    assert PluginTest1().get_plugin(\"some_key\") == \"value_2\"\n\n    PluginTest1.add_plugin(\"some_key\", \"value_3\", replace=False)\n\n    assert PluginTest1.get_plugin(\"some_key\") == \"value_2\"\n    assert PluginTest1().get_plugin(\"some_key\") == \"value_2\"\n\n\n@skipif_dask_expr_enabled()\ndef test_predicate_pushdown_simple(parquet_ddf):\n    filtered_df = parquet_ddf[parquet_ddf[\"a\"] > 1]\n    pushdown_df = attempt_predicate_pushdown(filtered_df)\n    got_filters = hlg_layer(pushdown_df.dask, \"read-parquet\").creation_info[\"kwargs\"][\n        \"filters\"\n    ]\n    got_filters = frozenset(frozenset(v) for v in got_filters)\n    expected_filters = [[(\"a\", \">\", 1)]]\n    
expected_filters = frozenset(frozenset(v) for v in expected_filters)\n    assert got_filters == expected_filters\n\n\n@skipif_dask_expr_enabled()\ndef test_predicate_pushdown_logical(parquet_ddf):\n    filtered_df = parquet_ddf[\n        (parquet_ddf[\"a\"] > 1) & (parquet_ddf[\"b\"] < 2) | (parquet_ddf[\"a\"] == -1)\n    ]\n\n    pushdown_df = attempt_predicate_pushdown(filtered_df)\n    got_filters = hlg_layer(pushdown_df.dask, \"read-parquet\").creation_info[\"kwargs\"][\n        \"filters\"\n    ]\n    got_filters = frozenset(frozenset(v) for v in got_filters)\n    expected_filters = [[(\"a\", \">\", 1), (\"b\", \"<\", 2)], [(\"a\", \"==\", -1)]]\n    expected_filters = frozenset(frozenset(v) for v in expected_filters)\n    assert got_filters == expected_filters\n\n\n@skipif_dask_expr_enabled()\ndef test_predicate_pushdown_in(parquet_ddf):\n    filtered_df = parquet_ddf[\n        (parquet_ddf[\"a\"] > 1) & (parquet_ddf[\"b\"] < 2)\n        | (parquet_ddf[\"a\"] == -1) & parquet_ddf[\"c\"].isin((\"A\", \"B\", \"C\"))\n        | ~parquet_ddf[\"b\"].isin((5, 6, 7))\n    ]\n    pushdown_df = attempt_predicate_pushdown(filtered_df)\n    got_filters = hlg_layer(pushdown_df.dask, \"read-parquet\").creation_info[\"kwargs\"][\n        \"filters\"\n    ]\n    got_filters = frozenset(frozenset(v) for v in got_filters)\n    expected_filters = [\n        [(\"b\", \"<\", 2), (\"a\", \">\", 1)],\n        [(\"a\", \"==\", -1), (\"c\", \"in\", (\"A\", \"B\", \"C\"))],\n        [(\"b\", \"not in\", (5, 6, 7))],\n    ]\n    expected_filters = frozenset(frozenset(v) for v in expected_filters)\n    assert got_filters == expected_filters\n\n\n@skipif_dask_expr_enabled()\ndef test_predicate_pushdown_isna(parquet_ddf):\n    filtered_df = parquet_ddf[\n        (parquet_ddf[\"a\"] > 1) & (parquet_ddf[\"b\"] < 2)\n        | (parquet_ddf[\"a\"] == -1) & ~parquet_ddf[\"c\"].isna()\n        | parquet_ddf[\"b\"].isna()\n    ]\n    pushdown_df = attempt_predicate_pushdown(filtered_df)\n    
got_filters = hlg_layer(pushdown_df.dask, \"read-parquet\").creation_info[\"kwargs\"][\n        \"filters\"\n    ]\n    got_filters = frozenset(frozenset(v) for v in got_filters)\n    expected_filters = [\n        [(\"b\", \"<\", 2), (\"a\", \">\", 1)],\n        [(\"a\", \"==\", -1), (\"c\", \"is not\", None)],\n        [(\"b\", \"is\", None)],\n    ]\n    expected_filters = frozenset(frozenset(v) for v in expected_filters)\n    assert got_filters == expected_filters\n\n\n@skipif_dask_expr_enabled()\ndef test_predicate_pushdown_add_filters(parquet_ddf):\n    filtered_df = parquet_ddf[(parquet_ddf[\"a\"] > 1) | (parquet_ddf[\"a\"] == -1)]\n    pushdown_df = attempt_predicate_pushdown(\n        filtered_df,\n        add_filters=(\"b\", \"<\", 2),\n    )\n    got_filters = hlg_layer(pushdown_df.dask, \"read-parquet\").creation_info[\"kwargs\"][\n        \"filters\"\n    ]\n    got_filters = frozenset(frozenset(v) for v in got_filters)\n    expected_filters = [\n        [(\"a\", \">\", 1), (\"b\", \"<\", 2)],\n        [(\"a\", \"==\", -1), (\"b\", \"<\", 2)],\n    ]\n    expected_filters = frozenset(frozenset(v) for v in expected_filters)\n    assert got_filters == expected_filters\n\n\n@skipif_dask_expr_enabled()\ndef test_predicate_pushdown_add_filters_no_extract(parquet_ddf):\n    filtered_df = parquet_ddf[(parquet_ddf[\"a\"] > 1) | (parquet_ddf[\"a\"] == -1)]\n    pushdown_df = attempt_predicate_pushdown(\n        filtered_df,\n        extract_filters=False,\n        add_filters=(\"b\", \"<\", 2),\n    )\n    got_filters = hlg_layer(pushdown_df.dask, \"read-parquet\").creation_info[\"kwargs\"][\n        \"filters\"\n    ]\n    got_filters = frozenset(frozenset(v) for v in got_filters)\n    expected_filters = [[(\"b\", \"<\", 2)]]\n    expected_filters = frozenset(frozenset(v) for v in expected_filters)\n    assert got_filters == expected_filters\n\n\n@skipif_dask_expr_enabled()\ndef test_predicate_pushdown_add_filters_no_preserve(parquet_ddf):\n    filtered_df = 
parquet_ddf[(parquet_ddf[\"a\"] > 1) | (parquet_ddf[\"a\"] == -1)]\n    pushdown_df0 = attempt_predicate_pushdown(filtered_df)\n    pushdown_df = attempt_predicate_pushdown(\n        pushdown_df0,\n        preserve_filters=False,\n        extract_filters=False,\n        add_filters=(\"b\", \"<\", 2),\n    )\n\n    got_filters = hlg_layer(pushdown_df.dask, \"read-parquet\").creation_info[\"kwargs\"][\n        \"filters\"\n    ]\n    got_filters = frozenset(frozenset(v) for v in got_filters)\n    expected_filters = [[(\"b\", \"<\", 2)]]\n    expected_filters = frozenset(frozenset(v) for v in expected_filters)\n    assert got_filters == expected_filters\n"
  },
  {
    "path": "tests/utils.py",
    "content": "import os\n\nimport pytest\nfrom dask.dataframe import _dask_expr_enabled\nfrom dask.dataframe.utils import assert_eq as _assert_eq\n\n# use the distributed scheduler for testing when DASK_SQL_DISTRIBUTED_TESTS\n# is set to \"true\" or \"1\"; otherwise fall back to the synchronous scheduler\nscheduler = (\n    \"distributed\"\n    if os.getenv(\"DASK_SQL_DISTRIBUTED_TESTS\", \"False\").lower() in (\"true\", \"1\")\n    else \"sync\"\n)\n\n\ndef assert_eq(*args, **kwargs):\n    \"\"\"\n    Wrap dask's `assert_eq`, defaulting to the scheduler selected above\n    \"\"\"\n    kwargs.setdefault(\"scheduler\", scheduler)\n\n    return _assert_eq(*args, **kwargs)\n\n\ndef convert_nullable_columns(df):\n    \"\"\"\n    Convert certain nullable columns in `df` to non-nullable columns\n    when trying to handle np.nan and pd.NA would otherwise cause issues.\n    \"\"\"\n    dtypes_mapping = {\n        \"Int64\": \"float64\",\n        \"Float64\": \"float64\",\n        \"boolean\": \"float64\",\n    }\n\n    for dtype, target_dtype in dtypes_mapping.items():\n        selected_cols = df.select_dtypes(include=[dtype]).columns.tolist()\n        if selected_cols:\n            df[selected_cols] = df[selected_cols].astype(target_dtype)\n\n    return df\n\n\ndef skipif_dask_expr_enabled(reason=None):\n    \"\"\"\n    Skip the test if dask-expr is enabled\n    \"\"\"\n    # most common reason for skipping\n    if reason is None:\n        reason = \"Predicate pushdown & column projection should be handled implicitly by dask-expr\"\n\n    return pytest.mark.skipif(\n        _dask_expr_enabled(),\n        reason=reason,\n    )\n"
  }
]